### Imports

In [1]:
# The code below uses few-shot learning to generalize a fine-tuned model on another task.

import os
import torch
from transformers import AutoImageProcessor, SwinForImageClassification, TrainingArguments, Trainer
import evaluate
from datasets import load_dataset, DatasetDict
import numpy as np
import random
from torch.utils.data import Dataset
from torchinfo import summary

print(os.getcwd())
os.chdir("..") # have to go up one directory, can also use os.chdir("..")
print(os.getcwd())

# CUDA check 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


e:\Projects & Temp\GitHub\thesis\RQ3
e:\Projects & Temp\GitHub\thesis
cuda


### Define Model and Data

In [2]:
# Choose model
model_path = './sdxl-fine-tune'
# model_path = './sdxl-fine-tune-art'

processor = AutoImageProcessor.from_pretrained(model_path)
#classifier = pipeline("image-classification", model=model_path, device=0 if torch.cuda.is_available() else -1)

# Choose dataset
# dataset_path = 'archive/datasets/faces_512x512'
dataset_path = 'archive/datasets/art_512x512'

ds = load_dataset("imagefolder", data_dir=dataset_path)
print(ds)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12800
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Custom Dataset Class 

In [3]:
# for the life of me i couldn't get the code to work with the normal huggingface dataset (despite the
# fact that the logic IS IDENTICAL to rq2_finetuning, but i digress) so here's a custom dataset class)

class CustomDataset(Dataset):
    def __init__(self, hf_dataset, processor, device):
        self.dataset = hf_dataset
        self.processor = processor
        self.device = device
        
    def __len__(self):
        return len(self.dataset)
    
    # manaully processes each image so pixel values are always returned ('image' is no longer needed for Trainer)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        inputs = self.processor(images=item['image'], return_tensors="pt")
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(item['label'])
        }

### Pre-Processing

In [4]:
# Turn dicts into tensors
def collate_fn(batch):
    pixel_values = torch.stack([x['pixel_values'] for x in batch])
    labels = torch.tensor([x['labels'] for x in batch])
    return {
        'pixel_values': pixel_values.to(device),
        'labels': labels.to(device)
    }

In [5]:
# Define metrics
acc_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

def compute_metrics(p):
    acc = acc_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    f1 = f1_metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
    return {"Accuracy": acc["accuracy"], "F1": f1["f1"]}

### Create Few-Shot Set

In [6]:
def create_few_shot_set(dataset, set_size, seed): # set_size = 10 -> 5-shot learning (2 classes)
    random.seed(seed)

    newset_indices = []

    label0_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 0]
    label1_indices = [i for i, label in enumerate(dataset['train']['label']) if label == 1]

    # get equal number of samples from each class
    newset0_indices = random.sample(label0_indices, set_size // 2)
    newset1_indices = random.sample(label1_indices, set_size // 2)
    newset_indices = newset0_indices + newset1_indices

    random.shuffle(newset_indices)

    few_shot_train = dataset['train'].select(newset_indices)
    val_set = dataset['validation'] # can use the full validation set

    return DatasetDict({
        'train': few_shot_train,
        'validation': val_set
    })

print(type(ds))

few_shot_ds = create_few_shot_set(ds, set_size=50, seed=42)

print(type(few_shot_ds))
print(few_shot_ds)

train_dataset = CustomDataset(few_shot_ds['train'], processor, device)
val_dataset = CustomDataset(few_shot_ds['validation'], processor, device)

# sample = train_dataset[0]
# print("Pixel values shape:", sample['pixel_values'].shape)
# print("Label:", sample['labels'])

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.dataset_dict.DatasetDict'>
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 50
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Load Model

In [7]:
# Use current dataset to extract the labels (doesn't really matter which dataset we use)
labels = ds['train'].features['label'].names
print(labels[0:2])

# Load the pre-trained model
model = SwinForImageClassification.from_pretrained(
    model_path, 
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
).to(device)

['0', '1']


### Prepare Model For Training

In [16]:
# Freezing earlier layers
summary(model, input_size=(1, 3, 224, 224)) 

for param in model.parameters():
    param.requires_grad = False # freeze all layers

for param in model.classifier.parameters():
    param.requires_grad = True # unfreeze the classifier layer

for param in model.swin.encoder.layers[-1].parameters():
    param.requires_grad = True # unfreeze the last layer of the encoder

In [None]:
# Training arguments
training_args = TrainingArguments(
    output_dir=model_path + "_few_shot_fr",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=4  # mimics having a larger batch size and makes training more stable
    num_train_epochs=3,
    learning_rate=1e-5, # need low learning rate to avoid forgetting of original dataset stuff
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="no",
    logging_steps=1,
    load_best_model_at_end=False,
    report_to="none",
    dataloader_pin_memory=False # causes runtime error otherwise
)



In [18]:
# Trainer 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    tokenizer=processor,
)

  trainer = Trainer(


### Training and Evaluation

In [19]:
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

metrics = trainer.evaluate(val_dataset)
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


  5%|▌         | 2/39 [00:00<00:08,  4.21it/s]

{'loss': 0.3442, 'grad_norm': 3.550860643386841, 'learning_rate': 4.871794871794872e-05, 'epoch': 0.08}


  5%|▌         | 2/39 [00:00<00:08,  4.21it/s]

{'loss': 0.7631, 'grad_norm': 5.930713176727295, 'learning_rate': 4.7435897435897435e-05, 'epoch': 0.15}


  8%|▊         | 3/39 [00:00<00:08,  4.47it/s]

{'loss': 0.8946, 'grad_norm': 6.941886901855469, 'learning_rate': 4.615384615384616e-05, 'epoch': 0.23}


 13%|█▎        | 5/39 [00:01<00:06,  4.90it/s]

{'loss': 0.7349, 'grad_norm': 16.359725952148438, 'learning_rate': 4.4871794871794874e-05, 'epoch': 0.31}


 15%|█▌        | 6/39 [00:01<00:06,  5.07it/s]

{'loss': 0.5275, 'grad_norm': 7.849201679229736, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}


 18%|█▊        | 7/39 [00:01<00:06,  5.00it/s]

{'loss': 0.8749, 'grad_norm': 6.876099586486816, 'learning_rate': 4.230769230769231e-05, 'epoch': 0.46}


 21%|██        | 8/39 [00:01<00:06,  5.13it/s]

{'loss': 0.7721, 'grad_norm': 10.175084114074707, 'learning_rate': 4.1025641025641023e-05, 'epoch': 0.54}


 23%|██▎       | 9/39 [00:01<00:05,  5.21it/s]

{'loss': 1.0275, 'grad_norm': 8.328709602355957, 'learning_rate': 3.974358974358974e-05, 'epoch': 0.62}


 26%|██▌       | 10/39 [00:02<00:05,  5.29it/s]

{'loss': 0.3776, 'grad_norm': 4.696666240692139, 'learning_rate': 3.846153846153846e-05, 'epoch': 0.69}
{'loss': 0.638, 'grad_norm': 6.699556350708008, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}


 31%|███       | 12/39 [00:02<00:04,  5.55it/s]

{'loss': 0.1641, 'grad_norm': 2.3011765480041504, 'learning_rate': 3.58974358974359e-05, 'epoch': 0.85}
{'loss': 0.183, 'grad_norm': 3.233072280883789, 'learning_rate': 3.461538461538462e-05, 'epoch': 0.92}


 33%|███▎      | 13/39 [00:02<00:04,  6.39it/s]

{'loss': 0.2037, 'grad_norm': 2.9990427494049072, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}



 33%|███▎      | 13/39 [00:53<00:04,  6.39it/s]  

{'eval_loss': 0.4592979848384857, 'eval_Accuracy': 0.790625, 'eval_F1': 0.7970926711084191, 'eval_runtime': 50.784, 'eval_samples_per_second': 31.506, 'eval_steps_per_second': 3.938, 'epoch': 1.0}


                                               

{'loss': 0.2341, 'grad_norm': 6.675081253051758, 'learning_rate': 3.205128205128206e-05, 'epoch': 1.08}


 41%|████      | 16/39 [00:53<02:56,  7.67s/it]

{'loss': 0.2389, 'grad_norm': 3.3428657054901123, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}


 44%|████▎     | 17/39 [00:54<01:59,  5.42s/it]

{'loss': 0.2607, 'grad_norm': 3.7554268836975098, 'learning_rate': 2.948717948717949e-05, 'epoch': 1.23}


 46%|████▌     | 18/39 [00:54<01:20,  3.85s/it]

{'loss': 0.5435, 'grad_norm': 7.24218225479126, 'learning_rate': 2.8205128205128207e-05, 'epoch': 1.31}


 49%|████▊     | 19/39 [00:54<00:55,  2.75s/it]

{'loss': 0.2198, 'grad_norm': 2.9780044555664062, 'learning_rate': 2.6923076923076923e-05, 'epoch': 1.38}


 51%|█████▏    | 20/39 [00:54<00:37,  1.99s/it]

{'loss': 0.121, 'grad_norm': 1.9317845106124878, 'learning_rate': 2.564102564102564e-05, 'epoch': 1.46}


 54%|█████▍    | 21/39 [00:54<00:26,  1.45s/it]

{'loss': 0.183, 'grad_norm': 2.8568222522735596, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}


 56%|█████▋    | 22/39 [00:55<00:18,  1.07s/it]

{'loss': 0.5833, 'grad_norm': 6.3516926765441895, 'learning_rate': 2.307692307692308e-05, 'epoch': 1.62}


 56%|█████▋    | 22/39 [00:55<00:18,  1.07s/it]

{'loss': 0.1547, 'grad_norm': 2.4694035053253174, 'learning_rate': 2.1794871794871795e-05, 'epoch': 1.69}


 62%|██████▏   | 24/39 [00:55<00:09,  1.60it/s]

{'loss': 0.61, 'grad_norm': 6.538387298583984, 'learning_rate': 2.0512820512820512e-05, 'epoch': 1.77}


 64%|██████▍   | 25/39 [00:55<00:06,  2.05it/s]

{'loss': 0.4445, 'grad_norm': 4.794601917266846, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.85}
{'loss': 0.0864, 'grad_norm': 1.8787509202957153, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}


 67%|██████▋   | 26/39 [00:55<00:04,  2.65it/s]

{'loss': 0.7786, 'grad_norm': 11.458024978637695, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}



 67%|██████▋   | 26/39 [01:46<00:04,  2.65it/s]  

{'eval_loss': 0.3526747524738312, 'eval_Accuracy': 0.85125, 'eval_F1': 0.8662921348314606, 'eval_runtime': 50.8009, 'eval_samples_per_second': 31.495, 'eval_steps_per_second': 3.937, 'epoch': 2.0}


 72%|███████▏  | 28/39 [01:46<02:00, 10.96s/it]

{'loss': 0.1485, 'grad_norm': 4.097623348236084, 'learning_rate': 1.5384615384615387e-05, 'epoch': 2.08}


 74%|███████▍  | 29/39 [01:47<01:17,  7.73s/it]

{'loss': 0.5753, 'grad_norm': 6.585324764251709, 'learning_rate': 1.4102564102564104e-05, 'epoch': 2.15}


 77%|███████▋  | 30/39 [01:47<00:49,  5.47s/it]

{'loss': 0.1304, 'grad_norm': 2.0621750354766846, 'learning_rate': 1.282051282051282e-05, 'epoch': 2.23}


 79%|███████▉  | 31/39 [01:47<00:31,  3.89s/it]

{'loss': 0.0484, 'grad_norm': 1.0147490501403809, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}


 82%|████████▏ | 32/39 [01:47<00:19,  2.78s/it]

{'loss': 0.049, 'grad_norm': 0.830826461315155, 'learning_rate': 1.0256410256410256e-05, 'epoch': 2.38}


 85%|████████▍ | 33/39 [01:47<00:12,  2.00s/it]

{'loss': 0.0502, 'grad_norm': 1.0126272439956665, 'learning_rate': 8.974358974358976e-06, 'epoch': 2.46}


 87%|████████▋ | 34/39 [01:48<00:07,  1.46s/it]

{'loss': 0.7022, 'grad_norm': 7.166563034057617, 'learning_rate': 7.692307692307694e-06, 'epoch': 2.54}


 90%|████████▉ | 35/39 [01:48<00:04,  1.08s/it]

{'loss': 0.1892, 'grad_norm': 2.381730318069458, 'learning_rate': 6.41025641025641e-06, 'epoch': 2.62}


 92%|█████████▏| 36/39 [01:48<00:02,  1.22it/s]

{'loss': 0.2847, 'grad_norm': 2.7699930667877197, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}


 95%|█████████▍| 37/39 [01:48<00:01,  1.59it/s]

{'loss': 0.0624, 'grad_norm': 1.2255979776382446, 'learning_rate': 3.846153846153847e-06, 'epoch': 2.77}


 97%|█████████▋| 38/39 [01:48<00:00,  2.03it/s]

{'loss': 0.2923, 'grad_norm': 4.88690710067749, 'learning_rate': 2.564102564102564e-06, 'epoch': 2.85}
{'loss': 0.043, 'grad_norm': 0.8367530107498169, 'learning_rate': 1.282051282051282e-06, 'epoch': 2.92}


100%|██████████| 39/39 [01:48<00:00,  2.66it/s]

{'loss': 0.2835, 'grad_norm': 10.385313034057617, 'learning_rate': 0.0, 'epoch': 3.0}



100%|██████████| 39/39 [02:39<00:00,  4.09s/it]  
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.3397517800331116, 'eval_Accuracy': 0.855, 'eval_F1': 0.8681818181818182, 'eval_runtime': 50.6064, 'eval_samples_per_second': 31.617, 'eval_steps_per_second': 3.952, 'epoch': 3.0}
{'train_runtime': 159.5938, 'train_samples_per_second': 0.94, 'train_steps_per_second': 0.244, 'train_loss': 0.380073163849421, 'epoch': 3.0}
***** train metrics *****
  epoch                    =        3.0
  total_flos               = 10944747GF
  train_loss               =     0.3801
  train_runtime            = 0:02:39.59
  train_samples_per_second =       0.94
  train_steps_per_second   =      0.244


100%|██████████| 200/200 [00:50<00:00,  3.99it/s]

***** eval metrics *****
  epoch                   =        3.0
  eval_Accuracy           =      0.855
  eval_F1                 =     0.8682
  eval_loss               =     0.3398
  eval_runtime            = 0:00:50.52
  eval_samples_per_second =     31.668
  eval_steps_per_second   =      3.959



