### Imports

In [1]:
# The code below uses few-shot learning to generalize a fine-tuned model on another task.

import os
import torch
from transformers import AutoImageProcessor, SwinForImageClassification, TrainingArguments, Trainer
import evaluate
from datasets import load_dataset, DatasetDict
import numpy as np
import random
from torch.utils.data import Dataset
from torchinfo import summary

print(os.getcwd())
# The model and dataset paths used below ('./sdxl-fine-tune', 'archive/datasets/...')
# are relative to the repository root, one level above this notebook's folder.
# Only step up when that data root is not visible yet, so re-running this cell
# does not keep climbing the directory tree.
if not os.path.isdir(os.path.join("archive", "datasets")):
    os.chdir("..")
print(os.getcwd())

# CUDA check (set device = torch.device("cpu") here to force CPU execution)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


e:\Projects & Temp\GitHub\thesis\RQ3
e:\Projects & Temp\GitHub\thesis
cuda


### Define Model and Data

In [2]:
# Choose model
# Fine-tuned checkpoint to adapt (alternative: './sdxl-fine-tune-art')
model_path = './sdxl-fine-tune'
# Target dataset for few-shot adaptation (alternative: 'archive/datasets/faces_512x512')
dataset_path = 'archive/datasets/art_512x512'

# Image pre-processing configuration stored alongside the checkpoint
processor = AutoImageProcessor.from_pretrained(model_path)

# "imagefolder" yields train/validation/test splits with 'image' and 'label' columns
ds = load_dataset("imagefolder", data_dir=dataset_path)
print(ds)

DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 12800
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Custom Dataset Class 

In [3]:
# I couldn't get this to work with the plain Hugging Face dataset (even though the
# logic is identical to rq2_finetuning), so a custom Dataset class is used instead.

class CustomDataset(Dataset):
    """Torch Dataset that pre-processes Hugging Face image examples on access.

    Every item is passed through ``processor`` individually, so consumers such
    as the Trainer always receive ``pixel_values``/``labels`` and never need the
    raw 'image' column.

    Args:
        hf_dataset: indexable dataset whose items have 'image' and 'label' keys.
        processor: image processor called as
            ``processor(images=..., return_tensors="pt")``.
        device: stored on the instance; not used by the methods of this class.
    """

    def __init__(self, hf_dataset, processor, device):
        self.dataset, self.processor, self.device = hf_dataset, processor, device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return one processed example as ``{'pixel_values', 'labels'}`` tensors."""
        example = self.dataset[idx]
        encoded = self.processor(images=example['image'], return_tensors="pt")
        # The processor output is batched; squeeze(0) drops the leading axis so the
        # collate function can stack single examples into a batch itself.
        pixel_values = encoded['pixel_values'].squeeze(0)
        label = torch.tensor(example['label'])
        return {'pixel_values': pixel_values, 'labels': label}

### Pre-Processing

In [4]:
# Turn dicts into tensors
def collate_fn(batch, device=None):
    """Stack per-example tensors (as produced by CustomDataset) into one batch.

    Args:
        batch: list of dicts, each holding a 'pixel_values' tensor (all of the
            same shape) and a scalar 'labels' tensor (or int).
        device: optional device to move the batch to. Defaults to None, which
            leaves the tensors where the dataset created them. The HF Trainer
            moves model inputs to its own device before the forward pass, so
            no move is needed when this is used as ``data_collator``.

    Returns:
        dict with 'pixel_values' of shape (len(batch), *pixel_shape) and a 1-D
        int64 'labels' tensor of length len(batch).
    """
    # Keep collation device-agnostic: moving to CUDA here forces
    # dataloader_pin_memory=False (only CPU tensors can be pinned) and cannot
    # run inside forked DataLoader worker processes.
    pixel_values = torch.stack([x['pixel_values'] for x in batch])
    labels = torch.stack([torch.as_tensor(x['labels']) for x in batch])
    if device is not None:
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)
    return {'pixel_values': pixel_values, 'labels': labels}

In [5]:
# Define metrics
# Metric implementations loaded from the `evaluate` library
acc_metric = evaluate.load('accuracy')
f1_metric = evaluate.load('f1')

def compute_metrics(p):
    """Compute accuracy and F1 for a Trainer evaluation pass.

    Args:
        p: Trainer prediction output with ``predictions`` (one row of class
            scores per example) and ``label_ids``.

    Returns:
        dict with keys "Accuracy" and "F1" (the Trainer reports them with an
        "eval_" prefix).
    """
    predicted = np.argmax(p.predictions, axis=1)  # highest-scoring class per row
    references = p.label_ids
    accuracy = acc_metric.compute(predictions=predicted, references=references)["accuracy"]
    f1 = f1_metric.compute(predictions=predicted, references=references)["f1"]
    return {"Accuracy": accuracy, "F1": f1}

### Create Few-Shot Set

In [6]:
def create_few_shot_set(dataset, set_size, seed, classes=(0, 1)):
    """Build a class-balanced few-shot training split.

    Draws ``set_size // len(classes)`` training examples per class from
    ``dataset['train']`` without replacement, shuffles them, and pairs them
    with the full validation split.

    Args:
        dataset: DatasetDict with 'train' and 'validation' splits; the 'train'
            split must have an integer 'label' column.
        set_size: total number of training examples to draw, e.g. 10 with two
            classes gives 5-shot learning. Any remainder after dividing by the
            number of classes is dropped.
        seed: seed for the sampling RNG, so the same subset is drawn every run.
        classes: label ids to sample from (default: the binary 0/1 setup).

    Returns:
        DatasetDict with 'train' (the few-shot subset) and 'validation'.

    Raises:
        ValueError: if a class has fewer than ``set_size // len(classes)``
            examples (raised by ``random.Random.sample``).
    """
    # Private RNG: reproducible per call without reseeding the global `random`
    # module. It produces the same draws as random.seed(seed) + random.sample.
    rng = random.Random(seed)
    per_class = set_size // len(classes)

    # Read the label column once and bucket indices by class in a single pass.
    train_labels = dataset['train']['label']
    indices_by_class = {c: [] for c in classes}
    for i, label in enumerate(train_labels):
        if label in indices_by_class:
            indices_by_class[label].append(i)

    # Equal number of samples from each class, drawn in class order.
    newset_indices = []
    for c in classes:
        newset_indices.extend(rng.sample(indices_by_class[c], per_class))
    rng.shuffle(newset_indices)

    few_shot_train = dataset['train'].select(newset_indices)
    val_set = dataset['validation']  # the full validation set is kept

    return DatasetDict({
        'train': few_shot_train,
        'validation': val_set
    })

# Sanity check: the loader returned a DatasetDict
print(type(ds))

# 50 images in total -> 25 per class (25-shot), with a fixed seed for reproducibility
few_shot_ds = create_few_shot_set(ds, set_size=50, seed=42)

print(type(few_shot_ds))
print(few_shot_ds)

# Wrap both splits so every item is pre-processed into pixel_values/labels
train_dataset, val_dataset = (
    CustomDataset(few_shot_ds[split], processor, device)
    for split in ('train', 'validation')
)

<class 'datasets.dataset_dict.DatasetDict'>
<class 'datasets.dataset_dict.DatasetDict'>
DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 50
    })
    validation: Dataset({
        features: ['image', 'label'],
        num_rows: 1600
    })
})


### Load Model

In [7]:
# Use current dataset to extract the labels (doesn't really matter which dataset we use)
labels = ds['train'].features['label'].names
print(labels[0:2])

# Load the fine-tuned checkpoint with a classification head sized to these labels.
# PretrainedConfig documents id2label as Dict[int, str] and label2id as
# Dict[str, int]; string ids would make int-keyed lookups such as
# model.config.id2label[pred] raise KeyError.
model = SwinForImageClassification.from_pretrained(
    model_path,
    num_labels=len(labels),
    id2label={i: c for i, c in enumerate(labels)},
    label2id={c: i for i, c in enumerate(labels)},
).to(device)

['0', '1']


### Prepare Model For Training

In [8]:
# Freezing earlier layers
# Freeze every parameter, then re-enable gradients only for the last encoder
# stage and the classification head.
model.requires_grad_(False)
model.classifier.requires_grad_(True)
model.swin.encoder.layers[-1].requires_grad_(True)

# Summarise after freezing so the trainable / non-trainable parameter counts
# reflect the frozen state used for training (printing first reports every
# parameter as trainable).
print(summary(model, input_size=(1, 3, 224, 224)))

Layer (type:depth-idx)                                            Output Shape              Param #
SwinForImageClassification                                        [1, 2]                    --
├─SwinModel: 1-1                                                  [1, 1024]                 --
│    └─SwinEmbeddings: 2-1                                        [1, 3136, 128]            --
│    │    └─SwinPatchEmbeddings: 3-1                              [1, 3136, 128]            6,272
│    │    └─LayerNorm: 3-2                                        [1, 3136, 128]            256
│    │    └─Dropout: 3-3                                          [1, 3136, 128]            --
│    └─SwinEncoder: 2-2                                           [1, 49, 1024]             --
│    │    └─ModuleList: 3-4                                       --                        86,734,648
│    └─LayerNorm: 2-3                                             [1, 49, 1024]             2,048
│    └─AdaptiveAvgPool1d: 2-4 

In [9]:
# Training arguments
training_args = TrainingArguments(
    output_dir=model_path + "_few_shot_fr",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    #gradient_accumulation_steps=4,  # mimics having a larger batch size and makes training more stable
    num_train_epochs=3,
    learning_rate=1e-5,  # low learning rate to limit forgetting of the original dataset
    logging_dir="./logs",
    eval_strategy="epoch",  # replaces the deprecated `evaluation_strategy` argument
    save_strategy="no",
    logging_steps=1,
    load_best_model_at_end=False,
    report_to="none",
    # Only CPU tensors can be pinned; keep this off while collate_fn hands back
    # tensors already on the GPU (it raised a RuntimeError otherwise).
    dataloader_pin_memory=False,
)



In [10]:
# Trainer 
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    # `tokenizer=` is deprecated for Trainer (it emitted a FutureWarning);
    # `processing_class` is its replacement and also accepts image processors.
    processing_class=processor,
)

  trainer = Trainer(


### Training and Evaluation

In [11]:
# Fine-tune on the few-shot split, then persist the model, metrics and trainer state
train_results = trainer.train()
trainer.save_model()
for record in (trainer.log_metrics, trainer.save_metrics):
    record("train", train_results.metrics)
trainer.save_state()

# Final evaluation on the full validation set
metrics = trainer.evaluate(val_dataset)
for record in (trainer.log_metrics, trainer.save_metrics):
    record("eval", metrics)


  5%|▌         | 2/39 [00:00<00:10,  3.56it/s]

{'loss': 1.7046, 'grad_norm': 11.971308708190918, 'learning_rate': 9.743589743589744e-06, 'epoch': 0.08}


  8%|▊         | 3/39 [00:00<00:08,  4.16it/s]

{'loss': 4.4136, 'grad_norm': 20.38154411315918, 'learning_rate': 9.487179487179487e-06, 'epoch': 0.15}


 10%|█         | 4/39 [00:00<00:07,  4.52it/s]

{'loss': 4.9174, 'grad_norm': 20.496736526489258, 'learning_rate': 9.230769230769232e-06, 'epoch': 0.23}


 13%|█▎        | 5/39 [00:01<00:07,  4.77it/s]

{'loss': 8.8523, 'grad_norm': 33.27217102050781, 'learning_rate': 8.974358974358976e-06, 'epoch': 0.31}


 15%|█▌        | 6/39 [00:01<00:06,  4.93it/s]

{'loss': 4.2191, 'grad_norm': 17.458572387695312, 'learning_rate': 8.717948717948719e-06, 'epoch': 0.38}


 18%|█▊        | 7/39 [00:01<00:06,  5.00it/s]

{'loss': 4.8329, 'grad_norm': 19.817663192749023, 'learning_rate': 8.461538461538462e-06, 'epoch': 0.46}


 21%|██        | 8/39 [00:01<00:06,  5.07it/s]

{'loss': 3.1196, 'grad_norm': 15.663534164428711, 'learning_rate': 8.205128205128205e-06, 'epoch': 0.54}


 23%|██▎       | 9/39 [00:01<00:05,  5.08it/s]

{'loss': 6.0908, 'grad_norm': 22.821640014648438, 'learning_rate': 7.948717948717949e-06, 'epoch': 0.62}


 26%|██▌       | 10/39 [00:02<00:05,  5.14it/s]

{'loss': 7.2526, 'grad_norm': 31.364831924438477, 'learning_rate': 7.692307692307694e-06, 'epoch': 0.69}


 26%|██▌       | 10/39 [00:02<00:05,  5.14it/s]

{'loss': 4.4207, 'grad_norm': 21.87225914001465, 'learning_rate': 7.435897435897437e-06, 'epoch': 0.77}


 31%|███       | 12/39 [00:02<00:05,  5.27it/s]

{'loss': 6.967, 'grad_norm': 37.18930435180664, 'learning_rate': 7.17948717948718e-06, 'epoch': 0.85}
{'loss': 2.9788, 'grad_norm': 20.67455291748047, 'learning_rate': 6.923076923076923e-06, 'epoch': 0.92}


 33%|███▎      | 13/39 [00:02<00:04,  6.05it/s]

{'loss': 7.996, 'grad_norm': 36.524375915527344, 'learning_rate': 6.666666666666667e-06, 'epoch': 1.0}


                                               
 33%|███▎      | 13/39 [00:54<00:04,  6.05it/s]  

{'eval_loss': 4.713573932647705, 'eval_Accuracy': 0.23875, 'eval_F1': 0.02090032154340836, 'eval_runtime': 51.6207, 'eval_samples_per_second': 30.995, 'eval_steps_per_second': 3.874, 'epoch': 1.0}


 38%|███▊      | 15/39 [00:54<04:25, 11.08s/it]

{'loss': 5.9891, 'grad_norm': 33.825714111328125, 'learning_rate': 6.410256410256412e-06, 'epoch': 1.08}


 41%|████      | 16/39 [00:54<02:59,  7.80s/it]

{'loss': 5.1009, 'grad_norm': 24.300981521606445, 'learning_rate': 6.153846153846155e-06, 'epoch': 1.15}


 44%|████▎     | 17/39 [00:55<02:01,  5.52s/it]

{'loss': 6.2494, 'grad_norm': 25.951318740844727, 'learning_rate': 5.897435897435898e-06, 'epoch': 1.23}


 46%|████▌     | 18/39 [00:55<01:22,  3.92s/it]

{'loss': 3.2935, 'grad_norm': 21.01800537109375, 'learning_rate': 5.641025641025641e-06, 'epoch': 1.31}


 46%|████▌     | 18/39 [00:55<01:22,  3.92s/it]

{'loss': 4.3378, 'grad_norm': 19.87311553955078, 'learning_rate': 5.384615384615385e-06, 'epoch': 1.38}


 51%|█████▏    | 20/39 [00:55<00:38,  2.02s/it]

{'loss': 5.0326, 'grad_norm': 23.236135482788086, 'learning_rate': 5.128205128205128e-06, 'epoch': 1.46}


 54%|█████▍    | 21/39 [00:55<00:26,  1.48s/it]

{'loss': 5.7995, 'grad_norm': 21.522876739501953, 'learning_rate': 4.871794871794872e-06, 'epoch': 1.54}


 56%|█████▋    | 22/39 [00:56<00:18,  1.09s/it]

{'loss': 4.6683, 'grad_norm': 20.781705856323242, 'learning_rate': 4.615384615384616e-06, 'epoch': 1.62}


 59%|█████▉    | 23/39 [00:56<00:13,  1.21it/s]

{'loss': 5.071, 'grad_norm': 19.936500549316406, 'learning_rate': 4.358974358974359e-06, 'epoch': 1.69}


 62%|██████▏   | 24/39 [00:56<00:09,  1.57it/s]

{'loss': 2.1666, 'grad_norm': 13.546972274780273, 'learning_rate': 4.102564102564103e-06, 'epoch': 1.77}


 64%|██████▍   | 25/39 [00:56<00:06,  2.01it/s]

{'loss': 2.4738, 'grad_norm': 22.658658981323242, 'learning_rate': 3.846153846153847e-06, 'epoch': 1.85}
{'loss': 4.2422, 'grad_norm': 19.864152908325195, 'learning_rate': 3.58974358974359e-06, 'epoch': 1.92}


 67%|██████▋   | 26/39 [00:56<00:04,  2.64it/s]

{'loss': 1.6224, 'grad_norm': 14.71154499053955, 'learning_rate': 3.3333333333333333e-06, 'epoch': 2.0}




RuntimeError: CUDA error: misaligned address
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
