<a href="https://colab.research.google.com/github/jlb-jlb/ML_Notebooks/blob/main/Bachelor_ViT_Seizure_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine-Tuning Vision Transformer for Seizure Classification

## Installing Dependencies

In [1]:
# blocks output in Colab 💄
%%capture

! pip install -q datasets transformers[torch]

! pip install -q evaluate

## Huggingface login

In [2]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

## Google Login

In [3]:
from google.colab import drive
# "/content/drive/MyDrive/Model_folder"
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Load Dataset

In [None]:
from datasets import load_dataset

# load dataset
# dataset_train = load_dataset("JLB-JLB/seizure_eeg_train", split="train")   # , streaming=True)
# dataset_dev = load_dataset("JLB-JLB/seizure_eeg_dev", split="train")   # , streaming=True)
# dataset_eval = load_dataset("JLB-JLB/seizure_eeg_eval", split="train")   # , streaming=True)

# display(dataset_train)
# display(dataset_dev)
# display(dataset_eval)


# dataset = load_dataset("JLB-JLB/seizure_eeg_greyscale_224x224_6secWindow")  # Load from Huggingface
# save dataset to google drive
# dataset.save_to_disk("/content/drive/MyDrive/Seizure_EEG_Research/Dataset") # save to drive

dataset = load_dataset("/content/drive/MyDrive/Seizure_EEG_Research/Dataset") # load from drive



Resolving data files:   0%|          | 0/51 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/26 [00:00<?, ?it/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [None]:
dataset

In [None]:
example = dataset["train"][400]
example

In [None]:
image = example["image"]
image

In [None]:
labels = dataset["train"].features["label"]
labels

In [None]:
labels.int2str(example['label'])

In [None]:
import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['label'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    # font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['label'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255)) #, font=font)

    return grid

# show_examples(dataset, seed=random.randint(0, 1337), examples_per_class=3)


## Load model

In [None]:
from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
feature_extractor

In [None]:
feature_extractor(image.convert("RGB"), return_tensors='pt')

In [None]:
def process_example(example):
    inputs = feature_extractor(example['image'].convert("RGB"), return_tensors='pt')
    inputs['label'] = example['label']
    return inputs

In [None]:
process_example(dataset['train'][0])

In [None]:
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x.convert("RGB") for x in example_batch['image']], return_tensors='pt')

    # Don't forget to include the labels!
    inputs['label'] = example_batch['label']
    return inputs


prepared_ds = dataset.with_transform(transform)

In [None]:
prepared_ds["train"][0:2]

## Training and Eval

In [None]:
import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['label'] for x in batch])
    }

In [None]:
# device will determine whether to run the training on GPU or CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
import numpy as np
import evaluate

metric_matthews_corrcoef = evaluate.load("matthews_correlation")

# metric

def compute_metrics(eval_preds):
    predictions = np.argmax(eval_preds.predictions, axis=1)
    references = eval_preds.label_ids
    return metric_matthews_corrcoef.compute(predictions=predictions, references=references)

print(metric_matthews_corrcoef.description)

In [None]:
from transformers import ViTForImageClassification

labels = dataset['train'].features['label'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

In [None]:
model.to(device)

In [None]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir= "/content/drive/MyDrive/Seizure_EEG_Research/ViT_Seizure_Detection", #"./model_vit_eeg",
  per_device_train_batch_size=64, # T4: 64, V100: 64, A100: 256
  evaluation_strategy="steps",
  num_train_epochs=4,
#   fp16=True,
  save_steps=5000,
  eval_steps=5000,
  logging_steps=20,
  learning_rate=1e-4,
  save_total_limit=4,
  remove_unused_columns=False,
  resume_from_checkpoint=True,
  push_to_hub=True,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

In [None]:
prepared_ds

In [None]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["test"],
    tokenizer=feature_extractor,
)

In [None]:
train_results = trainer.train() # "/content/drive/MyDrive/Model_folder/checkpoint-30")
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

In [None]:
metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

In [None]:
kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('VIT_SEIZURE_DETECTION', **kwargs)
else:
    trainer.create_model_card(**kwargs)