In [1]:
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
from datasets import load_dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import torchvision

In [2]:
model = torchvision.models.resnet18(weights=None, num_classes=200)
dataset_train = load_dataset("json", data_files="train_data.jsonl", split="train")
dataset_val = load_dataset("json", data_files="val_data.jsonl", split="train")
dataset_train = dataset_train.with_format("torch")
dataset_val = dataset_val.with_format("torch")

In [3]:
dataset_train

Dataset({
    features: ['image_path', 'class_name', 'class_label'],
    num_rows: 100000
})

In [4]:
# compose = transforms.Compose(
#     [
#         transforms.Resize((224, 224)),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
#     ]
# )

# torch.set_num_threads(1)


# def preprocess(batch):
#     images = []
#     for x in batch["image"]:
#         image = transforms.ToPILImage()(x)
#         if image.mode != "RGB":
#             image = image.convert("RGB")
#         image = compose(image)
#         images.append(image)
#     return {
#         "image": images,
#         "label": [x for x in batch["label"]],
#     }


# dataset = dataset.map(preprocess, batched=True, batch_size=100, num_proc=4)

In [5]:
# dataset.save_to_disk("tiny-imagenet-processed")

In [6]:
# def preprocess(batch):
#     images = []
#     for x in batch["image"]:
#         if x.shape[0] == 1:
#             x = x.repeat(3, 1, 1)
#         images.append(x.type(torch.uint8))
#     return {
#         "image": torch.stack(images),
#         "label": torch.tensor([x for x in batch["label"]]),
#     }


# dataset = dataset.map(preprocess, batched=True, batch_size=100, num_proc=4)

In [7]:
labels2id = dict()
c = 0
for i in range(len(dataset_val)):
    label = dataset_val[i]["class_name"]
    if label not in labels2id:
        labels2id[label] = c
        c += 1

In [8]:
# apply labels to dataset
def apply_labels(batch):
    labels = []
    images = []
    for x in batch["class_name"]:
        labels.append(labels2id[x])
    for image_path in batch["image_path"]:
        img = Image.open(image_path)
        if img.mode != "RGB":
            img = img.convert("RGB")
        img = transforms.PILToTensor()(img).to(torch.float32)
        img = transforms.Normalize(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )(img)
        images.append(img)
    return {
        "image": torch.stack(images),
        "label": torch.tensor([(x) for x in labels]),
    }


dataset_train = dataset_train.map(apply_labels, batched=True, num_proc=4)
dataset_val = dataset_val.map(apply_labels, batched=True, num_proc=4)

In [9]:
dataset_train[0]

{'image_path': './tiny-imagenet-200/train/n01443537/images/n01443537_0.JPEG',
 'class_name': 'n01443537',
 'class_label': 'goldfish, Carassius auratus',
 'image': tensor([[[1111.4192, 1111.4192, 1085.2184,  ...,  906.1790,  910.5458,
            552.4672],
          [1111.4192, 1107.0524, 1072.1179,  ...,  893.0786,  879.9781,
            535.0000],
          [1111.4192, 1111.4192, 1111.4192,  ...,  901.8122,  879.9781,
            521.8995],
          ...,
          [ 412.7292,  403.9956,  403.9956,  ...,  369.0611,  338.4934,
            408.3624],
          [ 386.5284,  386.5284,  386.5284,  ...,  373.4279,  334.1266,
            377.7947],
          [ 395.2620,  395.2620,  403.9956,  ...,  417.0961,  334.1266,
            312.2926]],
 
         [[ 605.1072,  614.0357,  649.7500,  ..., 1056.0000, 1073.8572,
            725.6429],
          [ 560.4642,  569.3928,  609.5714,  ..., 1029.2142, 1042.6072,
            707.7857],
          [ 560.4642,  573.8572,  600.6429,  ..., 1038.1428,

In [10]:
train_loader = DataLoader(dataset_train, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset_val, batch_size=128)

for x in train_loader:
    print(x)
    break

{'image_path': ['./tiny-imagenet-200/train/n07749582/images/n07749582_102.JPEG', './tiny-imagenet-200/train/n04532670/images/n04532670_59.JPEG', './tiny-imagenet-200/train/n09193705/images/n09193705_248.JPEG', './tiny-imagenet-200/train/n04118538/images/n04118538_60.JPEG', './tiny-imagenet-200/train/n07920052/images/n07920052_335.JPEG', './tiny-imagenet-200/train/n02769748/images/n02769748_77.JPEG', './tiny-imagenet-200/train/n02837789/images/n02837789_15.JPEG', './tiny-imagenet-200/train/n04023962/images/n04023962_19.JPEG', './tiny-imagenet-200/train/n04099969/images/n04099969_84.JPEG', './tiny-imagenet-200/train/n02281406/images/n02281406_5.JPEG', './tiny-imagenet-200/train/n02074367/images/n02074367_389.JPEG', './tiny-imagenet-200/train/n02125311/images/n02125311_372.JPEG', './tiny-imagenet-200/train/n01882714/images/n01882714_452.JPEG', './tiny-imagenet-200/train/n02883205/images/n02883205_450.JPEG', './tiny-imagenet-200/train/n09332890/images/n09332890_315.JPEG', './tiny-imagenet-

In [11]:
NUM_EPOCHS = 100
LEARNING_RATE = 0.001
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss = nn.CrossEntropyLoss()

model = torch.compile(model, mode="max-autotune")
model.to(device)
model.train()

train_loss = []
val_loss = []

In [12]:
from sklearn.metrics import accuracy_score, f1_score, classification_report
from tqdm.notebook import tqdm

torch.set_float32_matmul_precision("high")
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss.append(0)
    for i, batch in enumerate(tqdm(train_loader)):
        images = batch["image"]
        labels = batch["label"]
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        try:
            outputs = model(images)
        except Exception as e:
            print(images.dtype)
            raise e
        loss_val = loss(outputs, labels)
        train_loss[-1] += loss_val.item()
        loss_val.backward()
        optimizer.step()

    model.eval()

    val_loss.append(0)
    preds = []
    labs = []
    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            images = batch["image"]
            labels = batch["label"]
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds.append(outputs.argmax(dim=1).cpu().numpy())
            labs.append(labels.cpu().numpy())
            loss_val = loss(outputs, labels)
            val_loss[-1] += loss_val.item()

    train_loss[-1] /= len(train_loader)
    val_loss[-1] /= len(val_loader)

    preds = np.concatenate(preds)
    labs = np.concatenate(labs)
    acc = accuracy_score(labs, preds)
    f1 = f1_score(labs, preds, average="macro")
    print(classification_report(labs, preds))
    print(
        f"Epoch {epoch} - Train Loss: {train_loss[-1]} - Val Loss: {val_loss[-1]} - Acc: {acc} - F1: {f1}"
    )

  0%|          | 0/782 [00:00<?, ?it/s]

W0903 17:57:13.341000 140562075584320 torch/_inductor/utils.py:977] [0/0] Not enough SMs to use max_autotune_gemm mode
AUTOTUNE addmm(128x200, 128x512, 512x200)
  addmm 0.0902 ms 100.0%
  bias_addmm 0.0910 ms 99.1%
SingleProcess AUTOTUNE benchmarking takes 0.3588 seconds and 0.0000 seconds precompiling
AUTOTUNE addmm(32x200, 32x512, 512x200)
  bias_addmm 0.0821 ms 100.0%
  addmm 0.0824 ms 99.6%
SingleProcess AUTOTUNE benchmarking takes 0.2002 seconds and 0.0000 seconds precompiling


              precision    recall  f1-score   support

           0       0.38      0.10      0.16        50
           1       0.50      0.02      0.04        50
           2       0.12      0.08      0.10        50
           3       0.31      0.08      0.13        50
           4       0.40      0.04      0.07        50
           5       0.02      0.02      0.02        50
           6       0.00      0.00      0.00        50
           7       0.16      0.10      0.12        50
           8       0.50      0.02      0.04        50
           9       0.00      0.00      0.00        50
          10       0.11      0.08      0.09        50
          11       0.04      0.02      0.03        50
          12       0.00      0.00      0.00        50
          13       0.00      0.00      0.00        50
          14       0.24      0.14      0.18        50
          15       0.26      0.10      0.14        50
          16       0.14      0.46      0.22        50
          17       0.09    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.29      0.50      0.37        50
           1       0.00      0.00      0.00        50
           2       0.11      0.16      0.13        50
           3       0.61      0.28      0.38        50
           4       0.00      0.00      0.00        50
           5       0.10      0.12      0.11        50
           6       0.00      0.00      0.00        50
           7       0.22      0.46      0.29        50
           8       0.42      0.32      0.36        50
           9       0.00      0.00      0.00        50
          10       0.18      0.08      0.11        50
          11       0.28      0.20      0.23        50
          12       0.00      0.00      0.00        50
          13       0.11      0.02      0.03        50
          14       0.49      0.66      0.56        50
          15       0.17      0.44      0.24        50
          16       0.55      0.58      0.56        50
          17       0.35    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.45      0.46      0.46        50
           1       0.07      0.12      0.09        50
           2       0.35      0.14      0.20        50
           3       0.50      0.20      0.29        50
           4       0.14      0.02      0.04        50
           5       0.14      0.04      0.06        50
           6       0.02      0.02      0.02        50
           7       0.31      0.46      0.37        50
           8       0.45      0.34      0.39        50
           9       0.00      0.00      0.00        50
          10       0.25      0.18      0.21        50
          11       0.16      0.38      0.23        50
          12       0.25      0.26      0.26        50
          13       0.35      0.12      0.18        50
          14       0.47      0.74      0.58        50
          15       0.11      0.58      0.18        50
          16       0.30      0.64      0.41        50
          17       0.13    

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.69      0.48      0.56        50
           1       0.10      0.12      0.11        50
           2       0.24      0.18      0.20        50
           3       0.43      0.46      0.45        50
           4       0.33      0.02      0.04        50
           5       0.14      0.06      0.08        50
           6       0.07      0.02      0.03        50
           7       0.27      0.62      0.37        50
           8       0.22      0.42      0.29        50
           9       0.06      0.02      0.03        50
          10       0.36      0.18      0.24        50
          11       0.38      0.30      0.33        50
          12       0.14      0.36      0.20        50
          13       0.32      0.14      0.19        50
          14       0.29      0.84      0.43        50
          15       0.23      0.26      0.24        50
          16       0.80      0.48      0.60        50
          17       0.24    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.46      0.66      0.54        50
           1       0.25      0.06      0.10        50
           2       0.42      0.10      0.16        50
           3       0.50      0.36      0.42        50
           4       0.15      0.18      0.16        50
           5       0.17      0.02      0.04        50
           6       0.11      0.06      0.08        50
           7       0.39      0.22      0.28        50
           8       0.25      0.50      0.34        50
           9       0.11      0.18      0.14        50
          10       0.39      0.24      0.30        50
          11       0.62      0.16      0.25        50
          12       0.26      0.36      0.30        50
          13       0.32      0.12      0.17        50
          14       0.56      0.70      0.62        50
          15       0.26      0.22      0.24        50
          16       0.86      0.48      0.62        50
          17       0.32    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.77      0.46      0.57        50
           1       0.20      0.12      0.15        50
           2       0.12      0.36      0.18        50
           3       0.65      0.34      0.45        50
           4       0.13      0.24      0.17        50
           5       0.13      0.10      0.11        50
           6       0.07      0.10      0.08        50
           7       0.37      0.44      0.40        50
           8       0.39      0.44      0.41        50
           9       0.13      0.08      0.10        50
          10       0.16      0.40      0.23        50
          11       0.22      0.38      0.28        50
          12       0.35      0.12      0.18        50
          13       0.40      0.16      0.23        50
          14       0.65      0.56      0.60        50
          15       0.35      0.30      0.32        50
          16       0.28      0.72      0.41        50
          17       0.25    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.70      0.38      0.49        50
           1       0.08      0.18      0.11        50
           2       0.34      0.24      0.28        50
           3       0.27      0.46      0.34        50
           4       0.29      0.08      0.12        50
           5       0.18      0.04      0.07        50
           6       0.10      0.10      0.10        50
           7       0.68      0.30      0.42        50
           8       0.50      0.40      0.44        50
           9       0.13      0.10      0.11        50
          10       0.17      0.26      0.20        50
          11       0.45      0.26      0.33        50
          12       0.24      0.34      0.28        50
          13       0.31      0.18      0.23        50
          14       0.70      0.62      0.66        50
          15       0.31      0.30      0.30        50
          16       0.65      0.56      0.60        50
          17       0.29    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.59      0.48      0.53        50
           1       0.20      0.20      0.20        50
           2       0.36      0.24      0.29        50
           3       0.48      0.44      0.46        50
           4       0.15      0.12      0.13        50
           5       0.21      0.14      0.17        50
           6       0.03      0.02      0.03        50
           7       0.35      0.26      0.30        50
           8       0.37      0.58      0.45        50
           9       0.08      0.24      0.12        50
          10       0.24      0.20      0.22        50
          11       0.41      0.26      0.32        50
          12       0.25      0.32      0.28        50
          13       0.26      0.24      0.25        50
          14       0.66      0.62      0.64        50
          15       0.26      0.42      0.32        50
          16       0.58      0.64      0.61        50
          17       0.23    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.65      0.60      0.62        50
           1       0.29      0.14      0.19        50
           2       0.32      0.32      0.32        50
           3       0.28      0.44      0.34        50
           4       0.16      0.24      0.19        50
           5       0.09      0.10      0.10        50
           6       0.11      0.08      0.09        50
           7       0.40      0.34      0.37        50
           8       0.36      0.42      0.39        50
           9       0.13      0.14      0.13        50
          10       0.29      0.28      0.28        50
          11       0.35      0.32      0.33        50
          12       0.30      0.32      0.31        50
          13       0.24      0.38      0.29        50
          14       0.64      0.74      0.69        50
          15       0.23      0.28      0.25        50
          16       0.66      0.58      0.62        50
          17       0.43    

  0%|          | 0/782 [00:00<?, ?it/s]

              precision    recall  f1-score   support

           0       0.63      0.52      0.57        50
           1       0.12      0.14      0.13        50
           2       0.23      0.22      0.23        50
           3       0.31      0.42      0.36        50
           4       0.13      0.16      0.14        50
           5       0.15      0.16      0.16        50
           6       0.07      0.04      0.05        50
           7       0.27      0.32      0.29        50
           8       0.57      0.34      0.42        50
           9       0.24      0.10      0.14        50
          10       0.29      0.20      0.24        50
          11       0.16      0.38      0.22        50
          12       0.38      0.24      0.29        50
          13       0.26      0.20      0.22        50
          14       0.79      0.54      0.64        50
          15       0.24      0.32      0.27        50
          16       0.31      0.70      0.43        50
          17       0.19    

  0%|          | 0/782 [00:00<?, ?it/s]

KeyboardInterrupt: 