In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
!pip install lightning --quiet

In [None]:
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot styling for every figure in this notebook.
# NOTE(review): the "SimHei" font has to be installed for matplotlib to use it;
# confirm it is available in this environment.
sns.set_theme(
    style="darkgrid",
    font="SimHei",
    font_scale=1.5,
    rc={"axes.unicode_minus": False},
)

import torch
import torchmetrics
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models

import lightning.pytorch as pl
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

# Seed every random number generator used below so repeated runs start from
# the same state: Python's `random`, NumPy, and PyTorch (CPU and current GPU).
seed = 1
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(seed)
# Lightning's own helper; `workers=True` additionally asks Lightning to seed
# dataloader worker processes.
pl.seed_everything(seed, workers=True)

In [None]:
batch_size = 64

# Evaluation pipeline: resize every image to 224x224 and convert it to a tensor.
test_transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# Training pipeline: the same steps plus a random horizontal flip as augmentation.
train_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ]
)

# One ImageFolder per split; the validation and test splits share the
# augmentation-free transform.
train_dataset = datasets.ImageFolder(
    root="/kaggle/input/70-dog-breedsimage-data-set/train",
    transform=train_transform,
)
val_dataset = datasets.ImageFolder(
    root="/kaggle/input/70-dog-breedsimage-data-set/valid",
    transform=test_transform,
)
test_dataset = datasets.ImageFolder(
    root="/kaggle/input/70-dog-breedsimage-data-set/test",
    transform=test_transform,
)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
# NOTE(review): the validation loader is shuffled as well. The sample-preview
# cell further down takes its images from the first validation batch, so this
# flag determines which images it shows -- confirm before turning it off.
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class_names = train_dataset.classes
# Tally all labels in one pass; `minlength` keeps a zero entry for any class
# index that never occurs, so the counts line up with `class_names`.
class_count = np.bincount(train_dataset.targets, minlength=len(class_names)).tolist()
df = pd.DataFrame({"Class": class_names, "Count": class_count})

# Horizontal bars (one per class) showing how many training images each class has.
plt.figure(figsize=(12, 20), dpi=100)
ax = sns.barplot(x="Count", y="Class", data=df)
ax.set_title("Training images per class")
plt.tight_layout()
plt.show()

In [None]:
# Preview the first eight images of one validation batch with their class names.
images, labels = next(iter(val_loader))

plt.figure(figsize=(12, 20), dpi=100)
for i in range(8):
    ax = plt.subplot(8, 4, i + 1)
    # Move the first axis last before plotting (presumably channels-first
    # tensors from ToTensor -> height x width x channels for imshow).
    picture = images[i].permute(1, 2, 0).numpy()
    ax.imshow(picture)
    ax.set_title(class_names[labels[i]])
    ax.axis("off")
plt.tight_layout()
plt.show()

In [None]:
class LitModel(pl.LightningModule):
    """ResNet-18 image classifier initialised from ImageNet weights.

    The final fully connected layer is replaced by a fresh linear layer with
    ``num_classes`` outputs. Accuracy, macro precision, macro recall and macro
    F1 are tracked with a separate ``MetricCollection`` per stage so that
    training, validation and test never share accumulated metric state.

    Logged keys (unchanged): ``train_loss`` and ``train_{acc,prec,recall,f1score}``
    per step; ``val_loss`` and ``val_{acc,prec,recall,f1score}`` per epoch;
    ``test_{acc,prec,recall,f1score}`` per test run.

    Args:
        num_classes: Number of output classes of the classification head.
        freeze_backbone: If True, all pretrained layers are frozen and only the
            new classification head is trained. Defaults to False (fine-tune
            the whole network).
    """

    def __init__(self, num_classes=1000, freeze_backbone=False):
        super().__init__()
        self.model = models.resnet18(weights="IMAGENET1K_V1")
        if freeze_backbone:
            # Freeze before swapping the head so the new ``fc`` layer created
            # below keeps requires_grad=True.
            for param in self.model.parameters():
                param.requires_grad = False
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)

        # F1 is set to macro averaging explicitly so it is consistent with
        # precision and recall; left at the task wrapper's default it would be
        # micro-averaged, which for single-label multiclass data duplicates
        # accuracy.
        metrics = torchmetrics.MetricCollection(
            {
                "acc": torchmetrics.Accuracy(task="multiclass", num_classes=num_classes),
                "prec": torchmetrics.Precision(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
                "recall": torchmetrics.Recall(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
                "f1score": torchmetrics.F1Score(
                    task="multiclass", average="macro", num_classes=num_classes
                ),
            }
        )
        # Independent copies per stage; the prefix reproduces the logged key
        # names ("train_acc", "val_prec", "test_f1score", ...).
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

    def forward(self, x):
        """Return the class logits produced by the wrapped ResNet for ``x``."""
        return self.model(x)

    def configure_optimizers(self):
        """Create the Adam optimizer over all trainable parameters."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        return optim.Adam(
            trainable, lr=0.001, betas=(0.9, 0.99), eps=1e-08, weight_decay=1e-5
        )

    def training_step(self, batch, batch_idx):
        """Compute the cross-entropy loss for one batch and log per-step metrics.

        Returns:
            The loss tensor that Lightning back-propagates.
        """
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("train_loss", loss, on_step=True, on_epoch=False, prog_bar=True, logger=True)
        # Calling the collection returns the metric values of this batch only,
        # which is what gets logged at every step.
        self.log_dict(
            self.train_metrics(y_hat, y), on_step=True, on_epoch=False, logger=True
        )
        return loss

    def on_train_epoch_end(self):
        """Drop the state the training metrics accumulated during the epoch."""
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        """Log the epoch-averaged validation loss and accumulate metric state."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("val_loss", loss, on_step=False, on_epoch=True, logger=True)
        # Only accumulate here: macro-averaged metrics must be computed once
        # over the whole epoch, not averaged over per-batch values (a single
        # batch usually lacks most classes, which skews per-batch macro scores).
        self.val_metrics.update(y_hat, y)

    def on_validation_epoch_end(self):
        """Compute, log and reset the validation metrics for the finished epoch."""
        self.log_dict(self.val_metrics.compute(), logger=True)
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        """Accumulate metric state for one test batch."""
        x, y = batch
        y_hat = self(x)
        self.test_metrics.update(y_hat, y)

    def on_test_epoch_end(self):
        """Compute, log and reset the test metrics over the full test set."""
        self.log_dict(self.test_metrics.compute())
        self.test_metrics.reset()

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Return the predicted class index for every sample in ``batch``.

        ``batch`` is unpacked as ``(inputs, labels)``; the labels are ignored.
        """
        x, _ = batch
        y_hat = self(x)
        return torch.argmax(y_hat, dim=1)

In [None]:
# Size the classification head to the number of classes found in the training set.
num_classes = len(class_names)
model = LitModel(num_classes=num_classes)

# Metrics are written as CSV files under the current working directory.
logger = CSVLogger("./")

# Stop training once `val_loss` has not improved for 5 consecutive validation runs.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    min_delta=0.00,
    patience=5,
    verbose=False,
    mode="min",
)

trainer = pl.Trainer(
    logger=logger,
    callbacks=[early_stop_callback],
    max_epochs=20,
    deterministic=True,
    enable_progress_bar=True,
)
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)