In [1]:
import matplotlib.pyplot as plt
import numpy as np
import random
from collections import Counter
import pickle

import torchvision.models as models
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
from tqdm.notebook import tqdm


from transformers import ViTFeatureExtractor, ViTForImageClassification
from PIL import Image
from torch.utils.data import Subset
from datasets import Dataset

from transformers import TrainingArguments
from transformers import Trainer

from datasets import load_metric

from torch.utils.data import DataLoader

import pytorch_lightning as pl
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint

2022-12-24 18:11:44.974855: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [2]:
# Load the preprocessed images and their labels from a local pickle file.
# The context manager closes the file handle (the bare `open(...)` never did).
# NOTE: pickle.load can execute arbitrary code — only open pickles from a trusted source.
with open('images_10_3_normal.p', 'rb') as pickle_file:
    images, labels = pickle.load(pickle_file)

In [3]:
# Array dimensions; the recorded output is (11177, 100, 100, 3), presumably N x H x W x channels.
np.shape(images)

(11177, 100, 100, 3)

In [4]:
# Split the classes in a balanced way
def train_test_split(labels, split_ratio=0.8):
    # Calculate the class frequencies
    class_counts = Counter(labels)

    # Calculate the number of samples in the train and test sets
    num_samples = len(labels)
    split_ratio = split_ratio 
    num_train_samples = int(num_samples * split_ratio)
    num_test_samples = num_samples - num_train_samples

    # Create a list of tuples, where each tuple contains the class label and the corresponding indices of the samples
    label_indices = [(label, np.where(labels == label)[0]) for label in class_counts.keys()]

    # Initialize the train and test sets
    train_indices = []
    test_indices = []

    # Loop over the list of tuples
    for label, indices in label_indices:
        # Calculate the number of samples for this class
        num_samples = len(indices)

        # Calculate the number of samples in the train and test sets for this class
        num_train_samples = int(num_samples * split_ratio)
        num_test_samples = num_samples - num_train_samples

        # Select the train and test indices for this class
        train_indices += random.sample(list(indices), num_train_samples)
        test_indices += [i for i in indices if i not in train_indices]
    return train_indices, test_indices

In [None]:
# Split every class separately, then show per-class counts of both partitions
# to confirm the split is balanced.
# NOTE(review): split_ratio=0.4 puts only 40% of each class in the training set
# and the other 60% in validation — confirm this is intended.
random.seed(42)  # the split draws from Python's global `random`; seed it so re-runs give the same split
train_indices, test_indices = train_test_split(labels, split_ratio=0.4)

np.unique(labels[train_indices], return_counts=True), np.unique(labels[test_indices], return_counts=True)

In [None]:
# Count of distinct class labels in the full dataset.
num_classes = np.unique(labels).size
num_classes

In [None]:
# Preview one sample: render the image at index 1 as a PIL image.
Image.fromarray(images[1, ...])

In [None]:
# Pretrained checkpoint name; the same name is used to load the model below.
model_name_or_path = 'google/vit-base-patch16-224'
# Feature extractor that ships with that checkpoint; it encodes images in the collator.
feature_extractor = ViTFeatureExtractor.from_pretrained(
    pretrained_model_name_or_path=model_name_or_path
)

In [None]:
# Class ids are 0..num_classes-1 and each class name is the id's string form.
# Hugging Face configs expect id2label: {int id -> str name} and
# label2id: {str name -> int id}; build the maps with exactly those types.
id2label = {class_id: str(class_id) for class_id in range(num_classes)}
label2id = {name: class_id for class_id, name in id2label.items()}

id2label, label2id

In [None]:
class ImageClassificationCollator:
    """Collate ``(image, label)`` samples into one model-ready batch.

    All images of the batch are encoded in a single ``feature_extractor`` call
    (requesting PyTorch tensors). The labels are attached to the resulting
    encodings under the ``'labels'`` key as a ``torch.long`` tensor.
    """

    def __init__(self, feature_extractor):
        # Callable that turns a list of images into model inputs.
        self.feature_extractor = feature_extractor

    def __call__(self, batch):
        pixel_inputs = [sample[0] for sample in batch]
        encodings = self.feature_extractor(pixel_inputs, return_tensors='pt')
        targets = [sample[1] for sample in batch]
        encodings['labels'] = torch.tensor(targets, dtype=torch.long)
        return encodings

In [None]:
class DDSM(torch.utils.data.Dataset):
    """Map-style dataset over in-memory image arrays and their labels.

    Each item is ``(image, label)``. The image is built with
    ``Image.fromarray`` and, when a truthy ``transform`` is set, passed
    through it before being returned.
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # Dataset length follows the image container.
        return len(self.images)

    def __getitem__(self, idx):
        image = Image.fromarray(self.images[idx])
        label = self.labels[idx]
        return (self.transform(image) if self.transform else image), label

In [None]:
# Wrap the train / validation partitions as datasets (the DataLoaders are built in a later cell).
train_ds, val_ds = (
    DDSM(images[split], labels[split]) for split in (train_indices, test_indices)
)

In [None]:
collator = ImageClassificationCollator(feature_extractor)

# Both loaders share the batch size and collation; only training batches are shuffled.
loader_kwargs = dict(batch_size=8, collate_fn=collator)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, **loader_kwargs)

# Pretrained ViT with a classification head sized to this dataset's classes.
# `ignore_mismatched_sizes=True` lets loading proceed instead of raising when
# checkpoint weights do not match the model's shapes (presumably the original head).
model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label2id),
    ignore_mismatched_sizes=True,
)

In [None]:
# Sanity check: push one training batch through the model before training.
# no_grad: this pass is never back-propagated, so don't build an autograd graph
# (and keep its activations alive) for it.
d = next(iter(train_loader))
with torch.no_grad():
    sanity_output = model(**d)
sanity_output

In [None]:
class Classifier(pl.LightningModule):
    """LightningModule that trains a wrapped image-classification model.

    The wrapped ``model`` is called with a collated batch as keyword arguments.
    Its output must expose ``.loss`` and ``.logits``, as used by the steps below.

    Args:
        model: The wrapped classifier; its ``forward`` becomes this module's ``forward``.
        lr: Adam learning rate (recorded in ``self.hparams``).
        weight_decay: Adam ``weight_decay`` term (recorded in ``self.hparams`` so
            it is saved with checkpoints). Defaults to the previously
            hard-coded 0.0025.
        **kwargs: Extra hyper-parameter names to record via ``save_hyperparameters``.
    """

    def __init__(self, model, lr: float = 2e-5, weight_decay: float = 0.0025, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', 'weight_decay', *list(kwargs))
        self.model = model
        # Delegate straight to the wrapped model so `self(**batch)` returns its output object.
        self.forward = self.model.forward
        # NOTE(review): newer torchmetrics releases require e.g.
        # Accuracy(task="multiclass", num_classes=...) — confirm the installed version.
        self.val_acc = Accuracy()
        self.train_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        """Run one training batch, log loss and batch accuracy, and return the loss."""
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        acc = self.train_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("train_acc", acc, prog_bar=True)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch, log loss and batch accuracy, and return the loss."""
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        """Adam over all parameters, using the recorded ``lr`` and ``weight_decay``."""
        return torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

In [None]:
pl.seed_everything(42)
# NOTE(review): this seed is set only after the train/validation split and the model
# load in earlier cells, so it presumably does not make those steps reproducible.
classifier = Classifier(model, lr=2e-5)

checkpoint_callback = ModelCheckpoint(
    dirpath='./vit_content/trainmebby',
    filename='ViT-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',  # the metric logged in Classifier.validation_step
)
trainer = pl.Trainer(
    max_epochs=3,
    gpus=1,  # NOTE(review): newer PyTorch Lightning versions replace `gpus` with accelerator/devices — confirm
    precision=16,
    callbacks=[checkpoint_callback],
)
trainer.fit(classifier, train_loader, val_loader)