### Imports

In [2]:
# The code below uses few-shot learning to generalize a fine-tuned model on another task.

import os
import torch
from transformers import AutoImageProcessor, SwinForImageClassification, TrainingArguments, Trainer
import evaluate
from datasets import load_dataset, DatasetDict
import numpy as np
import random
from torch.utils.data import Dataset
from torchinfo import summary

print(os.getcwd())
os.chdir("..") # have to go up one directory, can also use os.chdir("..")
print(os.getcwd())

# CUDA check 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


c:\Users\metet\OneDrive\Documents\GitHub\thesis\RQ3
c:\Users\metet\OneDrive\Documents\GitHub\thesis
cuda


### Define Model and Data

In [4]:
# Choose model
model_path = './sdxl-fine-tune'
# model_path = './sdxl-fine-tune-art'

processor = AutoImageProcessor.from_pretrained(model_path)
#classifier = pipeline("image-classification", model=model_path, device=0 if torch.cuda.is_available() else -1)

# Choose dataset
# dataset_path = 'archive/datasets/faces_512x512'
dataset_path = 'archive/datasets/art_512x512'

ds = load_dataset("imagefolder", data_dir=dataset_path)
print(ds)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12800
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Custom Dataset Class 

In [5]:
# for the life of me i couldn't get the code to work with the normal huggingface dataset (despite the
# fact that the logic IS IDENTICAL to rq2_finetuning, but i digress) so here's a custom dataset class)

class CustomDataset(Dataset):
    def __init__(self, hf_dataset, processor, device):
        self.dataset = hf_dataset
        self.processor = processor
        self.device = device
        
    def __len__(self):
        return len(self.dataset)
    
    # manaully processes each image so pixel values are always returned ('image' is no longer needed for Trainer)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        inputs = self.processor(images=item['image'], return_tensors="pt")
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(item['label'])
        }

### Pre-Processing

In [6]:
# Turn dicts into tensors
def collate_fn(batch):
    pixel_values = torch.stack([x['pixel_values'] for x in batch])
    labels = torch.tensor([x['labels'] for x in batch])
    return {
        'pixel_values': pixel_values.to(device),
        'labels': labels.to(device)
    }

In [7]:
# Define metrics
acc_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

def compute_metrics(p):
    acc = acc_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    f1 = f1_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    return {"Accuracy": acc["accuracy"], "F1": f1["f1"]}

### Create Few-Shot Set

In [135]:
def create_few_shot_set(dataset, set_size, seed): # set_size = 10 -> 5-shot learning (2 classes)
    random.seed(seed)

    newset_indices = []

    label0_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 0]
    label1_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 1]

    # get equal number of samples from each class
    newset0_indices = random.sample(label0_indices, set_size // 2)
    newset1_indices = random.sample(label1_indices, set_size // 2)
    newset_indices = newset0_indices + newset1_indices

    random.shuffle(newset_indices)

    few_shot_train = dataset['train'].select(newset_indices)
    val_set = dataset['validation'] # can use the full validation set

    return DatasetDict({
        'train': few_shot_train,
        'validation': val_set
    })

print(type(ds))

few_shot_ds = create_few_shot_set(ds, set_size=50, seed=44)

print(type(few_shot_ds))
print(few_shot_ds)

train_dataset = CustomDataset(few_shot_ds['train'], processor, device)
val_dataset = CustomDataset(few_shot_ds['validation'], processor, device)

# sample = train_dataset[0]
# print("Pixel values shape:", sample['pixel_values'].shape)
# print("Label:", sample['labels'])

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.dataset_dict.DatasetDict'>
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 50
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Load Model

In [136]:
# Use current dataset to extract the labels (doesn't really matter which dataset we use)
labels = ds['train'].features['label'].names
print(labels[0:2])

# Load the pre-trained model
model = SwinForImageClassification.from_pretrained(
    model_path, 
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
).to(device)

['0', '1']


### Prepare Model For Training

In [137]:
# Freezing earlier layers
print(summary(model, input_size=(1, 3, 224, 224)))

for param in model.parameters():
    param.requires_grad = False # freeze all layers

for param in model.classifier.parameters():
    param.requires_grad = True # unfreeze the classifier layer

for param in model.swin.encoder.layers[-1].parameters():
    param.requires_grad = True # unfreeze the last layer of the encoder

for name, param in model.named_parameters():
    if "layernorm" in name.lower():
        param.requires_grad = True # unfreeze layernorm layers 


Layer (type:depth-idx)                                            Output Shape              Param #
SwinForImageClassification                                        [1, 2]                    --
├─SwinModel: 1-1                                                  [1, 1024]                 --
│    └─SwinEmbeddings: 2-1                                        [1, 3136, 128]            --
│    │    └─SwinPatchEmbeddings: 3-1                              [1, 3136, 128]            6,272
│    │    └─LayerNorm: 3-2                                        [1, 3136, 128]            256
│    │    └─Dropout: 3-3                                          [1, 3136, 128]            --
│    └─SwinEncoder: 2-2                                           [1, 49, 1024]             --
│    │    └─ModuleList: 3-4                                       --                        86,734,648
│    └─LayerNorm: 2-3                                             [1, 49, 1024]             2,048
│    └─AdaptiveAvgPool1d: 2-4 

In [138]:
# Training arguments
training_args = TrainingArguments(
    output_dir=model_path + "_few_shot",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    num_train_epochs=10, # increase num of epochs for longer training with smaller LR
    learning_rate=5e-6, # need low learning rate to avoid forgetting of original dataset stuff
    weight_decay=1e-4, # adds regularization to all layers, can help generalize better (?)
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=1,
    load_best_model_at_end=True,
    report_to="none",
    dataloader_pin_memory=False, # causes runtime error otherwise
)



In [139]:
from transformers import EarlyStoppingCallback

# Trainer 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=15)],
)

  trainer = Trainer(


### Training and Evaluation

In [140]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(val_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)



[A
[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:15<00:08,  3.02it/s]
[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:15<00:08,  3.02it/s]

{'loss': 5.5104, 'grad_norm': 36.6348991394043, 'learning_rate': 4.961538461538462e-06, 'epoch': 0.08}
{'loss': 4.9929, 'grad_norm': 21.52077865600586, 'learning_rate': 4.923076923076924e-06, 'epoch': 0.15}



[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]
[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]

{'loss': 4.1023, 'grad_norm': 17.66706657409668, 'learning_rate': 4.884615384615385e-06, 'epoch': 0.23}
{'loss': 6.1254, 'grad_norm': 27.463912963867188, 'learning_rate': 4.8461538461538465e-06, 'epoch': 0.31}



[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]
[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]

{'loss': 4.6446, 'grad_norm': 22.60356330871582, 'learning_rate': 4.807692307692308e-06, 'epoch': 0.38}
{'loss': 5.3485, 'grad_norm': 27.53818130493164, 'learning_rate': 4.76923076923077e-06, 'epoch': 0.46}



[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]
[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:16<00:08,  3.02it/s]

{'loss': 3.8789, 'grad_norm': 20.627696990966797, 'learning_rate': 4.730769230769231e-06, 'epoch': 0.54}
{'loss': 2.6522, 'grad_norm': 35.97277069091797, 'learning_rate': 4.692307692307693e-06, 'epoch': 0.62}



[A

[A[A                                         
                                                 
 80%|████████  | 104/130 [17:17<00:08,  3.02it/s]
[A

[A[A                                          
                                                 
 80%|████████  | 104/130 [17:17<00:08,  3.02it/s]

{'loss': 4.4418, 'grad_norm': 19.707015991210938, 'learning_rate': 4.653846153846154e-06, 'epoch': 0.69}
{'loss': 7.0472, 'grad_norm': 32.3298454284668, 'learning_rate': 4.615384615384616e-06, 'epoch': 0.77}



[A

[A[A                                          
                                                 
 80%|████████  | 104/130 [17:17<00:08,  3.02it/s]
[A

[A[A                                          
                                                 
 80%|████████  | 104/130 [17:17<00:08,  3.02it/s]

{'loss': 3.255, 'grad_norm': 23.454002380371094, 'learning_rate': 4.5769230769230775e-06, 'epoch': 0.85}
{'loss': 5.2227, 'grad_norm': 28.517230987548828, 'learning_rate': 4.538461538461539e-06, 'epoch': 0.92}




[A[A                                          
                                                 
 80%|████████  | 104/130 [17:17<00:08,  3.02it/s]

{'loss': 0.6576, 'grad_norm': 20.002845764160156, 'learning_rate': 4.5e-06, 'epoch': 1.0}




[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

[A[A

KeyboardInterrupt: 