### Imports

In [1]:
# The code below uses few-shot learning to adapt a fine-tuned model to another task.

import inspect
import os
import random

import evaluate
import numpy as np
import torch
from datasets import load_dataset, DatasetDict
from torch.utils.data import Dataset
from torchinfo import summary
from transformers import AutoImageProcessor, SwinForImageClassification, TrainingArguments, Trainer

# Work from the parent of the directory the kernel started in, so the relative
# model / dataset paths used further down resolve. The starting directory is
# remembered the first time this cell runs, which makes re-running the cell
# land in the same place instead of climbing one more level on every run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
print(NOTEBOOK_DIR)
os.chdir(os.path.dirname(NOTEBOOK_DIR))
print(os.getcwd())

# CUDA check: use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


e:\Projects & Temp\GitHub\thesis\RQ3
e:\Projects & Temp\GitHub\thesis
cuda


### Define Model and Data

In [2]:
# --- Model -------------------------------------------------------------
# Directory of the fine-tuned checkpoint; swap in the commented line to use the other one.
model_path = './sdxl-fine-tune'
# model_path = './sdxl-fine-tune-art'

# --- Data --------------------------------------------------------------
# Dataset directory; swap in the commented line to use the other one.
dataset_path = 'archive/datasets/art_512x512'
# dataset_path = 'archive/datasets/faces_512x512'

# The image processor is restored from the same directory as the model weights.
processor = AutoImageProcessor.from_pretrained(model_path)

# Load the dataset through the generic "imagefolder" builder and show its splits.
ds = load_dataset("imagefolder", data_dir=dataset_path)
print(ds)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12800
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Custom Dataset Class 

In [None]:
# Passing the plain Hugging Face dataset to the Trainer did not work here (even though the
# logic is identical to rq2_finetuning), so the data is wrapped in a custom Dataset class instead.

class CustomDataset(Dataset):
    def __init__(self, hf_dataset, processor, device):
        self.dataset = hf_dataset
        self.processor = processor
        self.device = device
        
    def __len__(self):
        return len(self.dataset)
    
    # manaully processes each image so pixel values are always returned ('image' is no longer needed for Trainer)
    def __getitem__(self, idx):
        item = self.dataset[idx]
        inputs = self.processor(images=item['image'], return_tensors="pt")
        return {
            'pixel_values': inputs['pixel_values'].squeeze(0),
            'labels': torch.tensor(item['label'])
        }

### Pre-Processing

In [4]:
# Turn dicts into tensors
def collate_fn(batch):
    """Merge a list of per-sample dicts into one batch dict of stacked tensors.

    Args:
        batch: Sequence of dicts, each with a ``'pixel_values'`` tensor and a
            ``'labels'`` entry (a 0-dim tensor or a plain int).

    Returns:
        Dict with ``'pixel_values'`` stacked along a new leading batch axis and
        ``'labels'`` as a 1-D tensor with one entry per sample.

    The tensors are returned on whatever device the samples already live on.
    Device placement is deliberately left to the training loop (the Hugging
    Face ``Trainer`` moves every batch to its own device before the forward
    pass), which keeps this function free of global state and safe to run
    inside DataLoader worker processes.
    """
    pixel_values = torch.stack([sample['pixel_values'] for sample in batch])
    # as_tensor accepts both tensor and plain-int labels without copying tensors.
    labels = torch.stack([torch.as_tensor(sample['labels']) for sample in batch])
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }

In [5]:
# Define metrics
# Load both metric objects once; compute_metrics reuses them on every evaluation pass.
acc_metric, f1_metric = (evaluate.load(metric_name) for metric_name in ('accuracy', 'f1'))

def compute_metrics(p):
    """Score one evaluation pass.

    Args:
        p: Object exposing ``predictions`` (per-class scores, one row per
            sample) and ``label_ids`` (reference class indices).

    Returns:
        Dict with the keys ``"Accuracy"`` and ``"F1"``.
    """
    # The predicted class is the index of the highest score in each row.
    predicted_classes = np.argmax(p.predictions, axis=1)
    references = p.label_ids

    scores = {}
    scores["Accuracy"] = acc_metric.compute(predictions=predicted_classes, references=references)["accuracy"]
    scores["F1"] = f1_metric.compute(predictions=predicted_classes, references=references)["f1"]
    return scores

### Create Few-Shot Set

In [None]:
def create_few_shot_set(dataset, set_size, seed):
    """Build a class-balanced few-shot training split for a two-class dataset.

    ``set_size // 2`` rows are drawn at random from the training rows labelled
    ``0`` and the same number from those labelled ``1`` (so ``set_size=10``
    gives 5-shot learning). An odd ``set_size`` is rounded down to the next
    even number of rows. The validation split is passed through unchanged.

    Args:
        dataset: Mapping with a ``'train'`` and a ``'validation'`` split. The
            train split must provide its ``'label'`` column via ``['label']``
            and support ``.select(indices)``.
        set_size: Total number of training rows wanted across both classes.
        seed: Seed for the sampling; the same seed yields the same rows.

    Returns:
        ``DatasetDict`` with the sampled ``'train'`` split and the full
        ``'validation'`` split.

    Raises:
        ValueError: If either class has fewer than ``set_size // 2`` rows
            (or ``set_size`` is negative).
    """
    # A private generator gives the same draws as seeding the global `random`
    # module, without silently resetting random state for the rest of the notebook.
    rng = random.Random(seed)

    # Read the label column once and bucket the row indices per class.
    label0_indices, label1_indices = [], []
    for i, label in enumerate(dataset['train']['label']):
        if label == 0:
            label0_indices.append(i)
        elif label == 1:
            label1_indices.append(i)

    shots_per_class = set_size // 2
    for class_id, pool in ((0, label0_indices), (1, label1_indices)):
        if shots_per_class > len(pool):
            raise ValueError(
                f"Cannot draw {shots_per_class} samples for label {class_id}: "
                f"only {len(pool)} available in the train split."
            )

    # get equal number of samples from each class (class 0 first, then class 1)
    newset_indices = rng.sample(label0_indices, shots_per_class)
    newset_indices += rng.sample(label1_indices, shots_per_class)

    # Mix the classes so the rows are not ordered by label.
    rng.shuffle(newset_indices)

    few_shot_train = dataset['train'].select(newset_indices)
    val_set = dataset['validation'] # can use the full validation set

    return DatasetDict({
        'train': few_shot_train,
        'validation': val_set
    })

# Size of the few-shot training split (half of it per class) and the sampling seed.
FEW_SHOT_SET_SIZE = 50
FEW_SHOT_SEED = 42

print(type(ds))

few_shot_ds = create_few_shot_set(ds, set_size=FEW_SHOT_SET_SIZE, seed=FEW_SHOT_SEED)

print(type(few_shot_ds))
print(few_shot_ds)

# Wrap both splits so every sample is delivered with pixel values and labels.
train_dataset, val_dataset = (
    CustomDataset(few_shot_ds[split_name], processor, device)
    for split_name in ('train', 'validation')
)

# sample = train_dataset[0]
# print("Pixel values shape:", sample['pixel_values'].shape)
# print("Label:", sample['labels'])

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.dataset_dict.DatasetDict'>
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 50
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})
Pixel values shape: torch.Size([3, 224, 224])
Label: tensor(0)


### Load Model

In [8]:
# Class names are read from the label feature of the loaded dataset
# (the original note says it does not really matter which dataset is used for this).
labels = ds['train'].features['label'].names
print(labels[:2])

# Index <-> name lookup tables; the indices are passed as strings.
num_classes = len(labels)
class_ids = [str(class_index) for class_index in range(num_classes)]

# Load the pre-trained model and place it on the selected device.
model = SwinForImageClassification.from_pretrained(
    model_path,
    num_labels=num_classes,
    id2label=dict(zip(class_ids, labels)),
    label2id=dict(zip(labels, class_ids)),
)
model = model.to(device)

['0', '1']


### Prepare Model For Training

In [9]:
# Freezing earlier layers: switch gradients off for the whole network, then
# back on only for the parts that should adapt to the few-shot data.
model.requires_grad_(False)                         # freeze all layers
model.classifier.requires_grad_(True)               # unfreeze the classifier layer
model.swin.encoder.layers[-1].requires_grad_(True)  # unfreeze the last layer of the encoder

# Summarise AFTER freezing so the trainable / non-trainable parameter counts
# describe the setup that is actually trained. The table is printed explicitly:
# a bare summary(...) call that is not the last expression of a cell has its
# result discarded, and torchinfo does not print by itself inside Jupyter.
print(summary(model, input_size=(1, 3, 224, 224), verbose=0))

In [10]:
# Training arguments
training_args = TrainingArguments(
    # --- where things are written ---
    output_dir=f"{model_path}_few_shot",
    logging_dir="./logs",
    report_to="none",
    # --- optimisation ---
    num_train_epochs=3,
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    # --- bookkeeping: log every step, evaluate once per epoch, no checkpoints ---
    logging_steps=1,
    # NOTE(review): newer transformers releases call this `eval_strategy` — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="no",
    load_best_model_at_end=False,
    # Pinning only applies to CPU batches; it stays off so a collator that
    # returns device tensors remains safe.
    dataloader_pin_memory=False,
)



In [11]:
# Trainer
# Newer transformers releases take the image processor through the
# `processing_class` keyword and emit a deprecation warning for the old
# `tokenizer` keyword; older releases only know `tokenizer`. Use whichever
# keyword the installed Trainer accepts so the cell runs warning-free on both.
_trainer_params = inspect.signature(Trainer.__init__).parameters
_processor_kwarg = "processing_class" if "processing_class" in _trainer_params else "tokenizer"

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,          # stacks per-sample dicts into batch tensors
    compute_metrics=compute_metrics,   # reports Accuracy and F1 at each evaluation
    **{_processor_kwarg: processor},   # stored alongside the model when it is saved
)

  trainer = Trainer(


### Training and Evaluation

In [13]:
# Fine-tune on the few-shot split, then persist the model, the training
# metrics and the trainer state (in that order).
train_results = trainer.train()
trainer.save_model()
for record in (trainer.log_metrics, trainer.save_metrics):
    record("train", train_results.metrics)
trainer.save_state()

# Final evaluation on the validation split; metrics are printed and saved.
metrics = trainer.evaluate(eval_dataset=val_dataset)
for record in (trainer.log_metrics, trainer.save_metrics):
    record("eval", metrics)



 33%|███▎      | 13/39 [00:25<00:50,  1.96s/it]

[A

[A[A                                        
[A                                            
[A



{'loss': 0.4973, 'grad_norm': 5.3154616355896, 'learning_rate': 4.871794871794872e-05, 'epoch': 0.08}


  5%|▌         | 2/39 [00:00<00:08,  4.62it/s][A[A

[A[A                                        
[A                                            
[A



{'loss': 1.6069, 'grad_norm': 10.542373657226562, 'learning_rate': 4.7435897435897435e-05, 'epoch': 0.15}


  8%|▊         | 3/39 [00:00<00:07,  4.66it/s][A[A

[A[A                                        
[A                                            
[A



{'loss': 1.8867, 'grad_norm': 11.788561820983887, 'learning_rate': 4.615384615384616e-05, 'epoch': 0.23}


 10%|█         | 4/39 [00:00<00:07,  4.80it/s][A[A

[A[A                                        
[A                                            
[A
[A



{'loss': 3.3087, 'grad_norm': 28.391359329223633, 'learning_rate': 4.4871794871794874e-05, 'epoch': 0.31}


[A[A                                        
[A                                            
[A



{'loss': 1.7128, 'grad_norm': 13.749529838562012, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}


 15%|█▌        | 6/39 [00:01<00:06,  5.02it/s][A[A

[A[A                                        
[A                                            
[A
[A



{'loss': 1.723, 'grad_norm': 9.954059600830078, 'learning_rate': 4.230769230769231e-05, 'epoch': 0.46}


[A[A                                        
[A                                            
[A
[A

{'loss': 0.986, 'grad_norm': 9.486756324768066, 'learning_rate': 4.1025641025641023e-05, 'epoch': 0.54}




[A[A                                        
[A                                            
[A
[A



{'loss': 1.6827, 'grad_norm': 12.045991897583008, 'learning_rate': 3.974358974358974e-05, 'epoch': 0.62}


[A[A                                        
[A                                            
[A
[A



{'loss': 1.6309, 'grad_norm': 19.067277908325195, 'learning_rate': 3.846153846153846e-05, 'epoch': 0.69}


[A[A                                         
[A                                            
[A



{'loss': 0.9635, 'grad_norm': 7.313977241516113, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}


 28%|██▊       | 11/39 [00:02<00:05,  5.12it/s][A[A

[A[A                                         
[A                                            
[A
[A

[A[A                                         
[A                                            
[A

{'loss': 1.1003, 'grad_norm': 16.558387756347656, 'learning_rate': 3.58974358974359e-05, 'epoch': 0.85}
{'loss': 0.5654, 'grad_norm': 8.403083801269531, 'learning_rate': 3.461538461538462e-05, 'epoch': 0.92}



[A

[A[A                                         
[A                                            
[A

{'loss': 0.9548, 'grad_norm': 18.45084571838379, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}




[A[A                                         
[A                     
[A

{'eval_loss': 1.011102557182312, 'eval_Accuracy': 0.426875, 'eval_F1': 0.2669864108713029, 'eval_runtime': 67.0611, 'eval_samples_per_second': 23.859, 'eval_steps_per_second': 2.982, 'epoch': 1.0}



[A

[A[A                                         
[A



{'loss': 0.7071, 'grad_norm': 17.027502059936523, 'learning_rate': 3.205128205128206e-05, 'epoch': 1.08}


 38%|███▊      | 15/39 [01:10<05:44, 14.34s/it][A[A

[A[A                                         
[A



{'loss': 0.7829, 'grad_norm': 8.261507034301758, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}


 41%|████      | 16/39 [01:10<03:52, 10.09s/it][A[A

[A[A                                         
[A
[A

{'loss': 0.8882, 'grad_norm': 9.258219718933105, 'learning_rate': 2.948717948717949e-05, 'epoch': 1.23}




[A[A                                         
[A
[A

{'loss': 0.446, 'grad_norm': 7.603533744812012, 'learning_rate': 2.8205128205128207e-05, 'epoch': 1.31}




[A[A                                         
[A
[A



{'loss': 0.4676, 'grad_norm': 3.012171506881714, 'learning_rate': 2.6923076923076923e-05, 'epoch': 1.38}


[A[A                                         
[A

{'loss': 0.5315, 'grad_norm': 6.300508499145508, 'learning_rate': 2.564102564102564e-05, 'epoch': 1.46}



[A

[A[A                                         
[A

{'loss': 0.3477, 'grad_norm': 4.714243412017822, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}



[A

[A[A                                         
[A
[A

[A[A                                         
[A

{'loss': 0.882, 'grad_norm': 6.261834144592285, 'learning_rate': 2.307692307692308e-05, 'epoch': 1.62}
{'loss': 0.6222, 'grad_norm': 6.355686187744141, 'learning_rate': 2.1794871794871795e-05, 'epoch': 1.69}



[A

[A[A                                         
[A
[A



{'loss': 0.5886, 'grad_norm': 4.098194599151611, 'learning_rate': 2.0512820512820512e-05, 'epoch': 1.77}


[A[A                                         
[A
[A

[A[A                                         
[A

{'loss': 0.4676, 'grad_norm': 4.84315299987793, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.85}
{'loss': 0.6052, 'grad_norm': 9.499213218688965, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}



[A

[A[A                                         
[A

{'loss': 0.3749, 'grad_norm': 6.581021785736084, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}


                                                 

100%|██████████| 200/200 [00:52<00:00,  4.35it/s]
                                                 

{'eval_loss': 0.5024647116661072, 'eval_Accuracy': 0.7675, 'eval_F1': 0.7750906892382103, 'eval_runtime': 52.9962, 'eval_samples_per_second': 30.191, 'eval_steps_per_second': 3.774, 'epoch': 2.0}



[A

[A[A                                         
[A
[A



{'loss': 0.3565, 'grad_norm': 5.309319972991943, 'learning_rate': 1.5384615384615387e-05, 'epoch': 2.08}


[A[A                                         
[A
[A



{'loss': 0.7706, 'grad_norm': 3.869166612625122, 'learning_rate': 1.4102564102564104e-05, 'epoch': 2.15}


[A[A                                         
[A
[A



{'loss': 0.4989, 'grad_norm': 3.958904981613159, 'learning_rate': 1.282051282051282e-05, 'epoch': 2.23}


[A[A                                         
[A
[A



{'loss': 0.1959, 'grad_norm': 2.2379508018493652, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}


[A[A                                         
[A
[A



{'loss': 0.1847, 'grad_norm': 2.357346534729004, 'learning_rate': 1.0256410256410256e-05, 'epoch': 2.38}


[A[A                                         
[A
[A

[A[A                                         



{'loss': 0.2677, 'grad_norm': 3.655817985534668, 'learning_rate': 8.974358974358976e-06, 'epoch': 2.46}
{'loss': 0.1307, 'grad_norm': 1.6449494361877441, 'learning_rate': 7.692307692307694e-06, 'epoch': 2.54}


 85%|████████▍ | 33/39 [02:06<00:12,  2.08s/it][A[A
[A

[A[A                                         
[A



{'loss': 0.5342, 'grad_norm': 3.42049241065979, 'learning_rate': 6.41025641025641e-06, 'epoch': 2.62}


 90%|████████▉ | 35/39 [02:06<00:04,  1.11s/it][A[A

[A[A                                         
[A
[A

[A[A                                         

 92%|█████████▏| 36/39 [02:07<00:02,  1.20it/s]

{'loss': 0.6353, 'grad_norm': 3.7580440044403076, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}
{'loss': 0.2656, 'grad_norm': 3.225121259689331, 'learning_rate': 3.846153846153847e-06, 'epoch': 2.77}


[A[A
[A

[A[A                                         
[A
[A

[A[A                                         
[A

{'loss': 0.4029, 'grad_norm': 2.981106758117676, 'learning_rate': 2.564102564102564e-06, 'epoch': 2.85}
{'loss': 0.109, 'grad_norm': 1.5377873182296753, 'learning_rate': 1.282051282051282e-06, 'epoch': 2.92}



[A

[A[A                                         
[A

{'loss': 0.4289, 'grad_norm': 13.870848655700684, 'learning_rate': 0.0, 'epoch': 3.0}


                                                 

100%|██████████| 200/200 [00:52<00:00,  4.10it/s]
                                                 

[A[A                                         
100%|██████████| 39/39 [03:00<00:00,  4.63s/it]
Non-default generation parameters: {'max_length': 128}


{'eval_loss': 0.44690197706222534, 'eval_Accuracy': 0.81125, 'eval_F1': 0.8221436984687868, 'eval_runtime': 53.0702, 'eval_samples_per_second': 30.149, 'eval_steps_per_second': 3.769, 'epoch': 3.0}
{'train_runtime': 180.5718, 'train_samples_per_second': 0.831, 'train_steps_per_second': 0.216, 'train_loss': 0.797732609586838, 'epoch': 3.0}
***** train metrics *****
  epoch                    =        3.0
  total_flos               = 10944747GF
  train_loss               =     0.7977
  train_runtime            = 0:03:00.57
  train_samples_per_second =      0.831
  train_steps_per_second   =      0.216


100%|██████████| 200/200 [00:52<00:00,  3.82it/s]

***** eval metrics *****
  epoch                   =        3.0
  eval_Accuracy           =     0.8113
  eval_F1                 =     0.8221
  eval_loss               =     0.4469
  eval_runtime            = 0:00:52.72
  eval_samples_per_second =     30.345
  eval_steps_per_second   =      3.793



