### Imports

In [None]:
# The code below uses few-shot learning to generalize a fine-tuned model on another task.

import os
import torch
from transformers import AutoImageProcessor, SwinForImageClassification, TrainingArguments, Trainer
import evaluate
from datasets import load_dataset
import numpy as np
from transformers import pipeline
from torch.utils.data import Subset
import random
from datasets import DatasetDict
from torchinfo import summary

print(os.getcwd())
os.chdir("..") # have to go up one directory, can also use os.chdir("..")
print(os.getcwd())

# CUDA check 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

e:\Projects & Temp\GitHub\thesis
e:\Projects & Temp\GitHub\thesis
cuda


### Define Model and Data

In [5]:
# Choose model
model_path = './sdxl-fine-tune'
# model_path = './sdxl-fine-tune-art'

processor = AutoImageProcessor.from_pretrained(model_path)
classifier = pipeline("image-classification", model=model_path, device=0 if torch.cuda.is_available() else -1)

# Choose dataset
# dataset_path = 'archive/datasets/faces_512x512'
dataset_path = 'archive/datasets/art_512x512'

ds = load_dataset("imagefolder", data_dir=dataset_path)
print(ds)

Device set to use cuda:0


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12800
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Pre-Processing

In [6]:
# Transform images to model input
def transform(image_batch):
    inputs = processor(images=image_batch['image'], return_tensors="pt")
    inputs['labels'] = torch.tensor(image_batch['label']).to(device)  # Ensure labels are tensors
    inputs['pixel_values'] = inputs['pixel_values'].to(device) 
    return inputs

In [7]:
# Turn dicts into tensors
def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]).to(device),
        'labels': torch.tensor([x['labels'] for x in batch]).to(device)
    }

In [8]:
# Define metrics
acc_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

def compute_metrics(p):
    acc = acc_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    f1 = f1_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    return {"Accuracy": acc["accuracy"], "F1": f1["f1"]}

### Create Few-Shot Set

In [41]:
def create_few_shot_set(dataset, set_size, seed): # set_size = 10 -> 5-shot learning (2 classes)
    random.seed(seed)

    newset_indices = []

    label0_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 0]
    label1_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 1]

    # get equal number of samples from each class
    newset0_indices = random.sample(label0_indices, set_size // 2)
    newset1_indices = random.sample(label1_indices, set_size // 2)
    newset_indices = newset0_indices + newset1_indices

    random.shuffle(newset_indices)

    few_shot_train = dataset['train'].select(newset_indices)
    val_set = dataset['validation'] # can use the full validation set

    return DatasetDict({
        'train': few_shot_train,
        'validation': val_set
    })

print(type(ds))

few_shot_ds = create_few_shot_set(ds, set_size=50, seed=42)

print(type(few_shot_ds))
print(few_shot_ds)

fs_transformed = few_shot_ds.with_transform(transform)

print(type(fs_transformed))

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.dataset_dict.DatasetDict'>
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 50
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})
<class 'datasets.dataset_dict.DatasetDict'>


### Load Model

In [42]:
# Use current dataset to extract the labels (doesn't really matter which dataset we use)
labels = fs_transformed['train'].features['label'].names
print(labels[0:2])

# Load the pre-trained model
model = SwinForImageClassification.from_pretrained(
    model_path, 
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
).to(device)

['0', '1']


### Prepare Model For Training

In [43]:
# Freezing earlier layers
summary(model, input_size=(1, 3, 224, 224)) 

for param in model.parameters():
    param.requires_grad = False # freeze all layers

for param in model.classifier.parameters():
    param.requires_grad = True # unfreeze the classifier layer

# for param in model.swin.encoder.layer[-1].parameters():
#     param.requires_grad = True # unfreeze the last layer of the encoder

In [44]:
# Training arguments
training_args = TrainingArguments(
    output_dir=model_path + "_few_shot",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    learning_rate=5e-5, # same LR as in fine-tuning but this time with scheduler
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="no",
    logging_steps=1,
    load_best_model_at_end=False,
    report_to="none"
)



In [45]:
# Trainer 
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=fs_transformed['train'],
    eval_dataset=fs_transformed['validation'],
    tokenizer=processor,
)

  trainer = Trainer(


### Training and Evaluation

In [46]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(fs_transformed['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


  0%|          | 0/39 [05:16<?, ?it/s]
  0%|          | 0/39 [00:00<?, ?it/s]

KeyError: 'image'