# 1. Install Dependencies



In [None]:
from os.path import exists

# Notebook shell commands (IPython `!` syntax): install torchinfo (imported by
# the main training cell for an optional model summary) and list the GPUs.
!pip install torchinfo
!nvidia-smi

# Fetch and extract the pre-transformed dataset only when the archive is not
# already present; the archive file itself acts as the "already fetched" marker.
# NOTE(review): if the .zip exists but was never extracted (or the extracted
# 'transformed_dataset' directory was removed), this cell will not re-extract
# it — checking for the extracted directory may be more reliable.
if not exists('transformed_dataset.zip'):
    # -c resumes a partial download; --no-check-certificate disables TLS
    # certificate verification for this download.
    !wget -c --no-check-certificate "https://onedrive.live.com/download?cid=B6A8D8D812274310&resid=B6A8D8D812274310%211418&authkey=AItiimMPx0gCpQo" -O transformed_dataset.zip
    !unzip transformed_dataset

# 2. Data Setup

In [None]:
import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# os.cpu_count() returns None when the count cannot be determined, which
# DataLoader does not accept; fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0


def create_dataloaders(
    train_path: str,
    test_path: str,
    train_transform: transforms.Compose,
    test_transform: transforms.Compose,
    batch_size: int,
    num_workers: int = NUM_WORKERS,
    pin_memory: bool = True,
):
    """Build training and test DataLoaders from two image-folder datasets.

    Args:
        train_path: Root directory of the training images, in the layout
            expected by ``torchvision.datasets.ImageFolder``.
        test_path: Root directory of the test images, same layout.
        train_transform: Transform applied to every training image.
        test_transform: Transform applied to every test image.
        batch_size: Number of samples per batch for both loaders.
        num_workers: Worker processes per loader (defaults to the CPU count,
            or 0 if it cannot be determined).
        pin_memory: Forwarded to both DataLoaders; page-locked host memory is
            only useful when batches are copied to a CUDA device.

    Returns:
        ``(train_dataloader, test_dataloader, class_names)``. ``class_names``
        is the class list discovered under ``train_path``. The training loader
        shuffles every epoch; the test loader keeps a fixed order.
    """
    train_data = datasets.ImageFolder(train_path, transform=train_transform)
    test_data = datasets.ImageFolder(test_path, transform=test_transform)

    # NOTE(review): labels are only comparable between the two splits if
    # test_path contains the same set of class sub-directories as train_path.
    class_names = train_data.classes

    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_dataloader, test_dataloader, class_names

# 2.1 Saving the Model

In [None]:
import torch
from pathlib import Path

def save_model(model: torch.nn.Module,
               target_path: str,
               model_name: str) -> Path:
    """Save ``model.state_dict()`` to ``target_path / model_name``.

    Only the state dict (parameters and buffers) is written, not the module
    object. Reloading therefore needs the same architecture plus
    ``load_state_dict``.

    Args:
        model: Module whose state dict is saved.
        target_path: Output directory; created, including parents, if missing.
        model_name: File name, which must end with ``.pt`` or ``.pth``.

    Returns:
        Path of the written file.

    Raises:
        ValueError: If ``model_name`` does not end with ``.pt`` or ``.pth``.
    """
    # Validate before touching the filesystem so a bad name leaves no empty
    # directory behind. Raise explicitly: `assert` is stripped under `python -O`.
    if not model_name.endswith((".pth", ".pt")):
        raise ValueError("model_name should end with '.pt' or '.pth'")

    target_dir_path = Path(target_path)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    model_save_path = target_dir_path / model_name
    torch.save(obj=model.state_dict(), f=model_save_path)
    return model_save_path

# 3. Test & Training Loops

In [None]:
import torch

from tqdm.auto import tqdm
from typing import Dict, List, Tuple


def train_step(model: torch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Run one training epoch over ``dataloader``.

    Puts ``model`` in training mode. For every batch it computes the loss,
    back-propagates it and takes one ``optimizer`` step.

    Args:
        model: Network being trained.
        dataloader: Iterable of ``(inputs, targets)`` batches that supports
            ``len()``.
        loss_fn: Criterion called as ``loss_fn(logits, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, mean_accuracy)``, each averaged over batches. Every
        batch counts equally, so a smaller final batch has the same weight as
        a full one.
    """
    model.train()

    train_loss, train_accuracy = 0.0, 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        y_pred = model(X)  # raw logits; class scores along dim 1

        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Take argmax over the logits directly, matching test_step. Softmax is
        # monotonic, so applying it first cannot improve the prediction. It can,
        # however, round nearly-equal logits into exact float ties, and argmax
        # then picks the first tied index.
        y_pred_class = y_pred.argmax(dim=1)
        train_accuracy += (y_pred_class == y).sum().item() / len(y_pred)

    train_loss = train_loss / len(dataloader)
    train_accuracy = train_accuracy / len(dataloader)
    return train_loss, train_accuracy


def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    model.eval()

    test_loss, test_accuracy = 0, 0

    with torch.inference_mode():
        for _, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)

            test_pred_logits = model(X)

            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()

            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_accuracy += ((test_pred_labels == y).sum().item() /
                              len(test_pred_labels))

    test_loss = test_loss / len(dataloader)
    test_accuracy = test_accuracy / len(dataloader)
    return test_loss, test_accuracy


def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          checkpoint_dir: str = "models",
          checkpoint_name: str = "efficientnet_v2_s_384_none.pth") -> Dict[str, List]:
    """Train ``model`` for ``epochs`` epochs, evaluating and checkpointing each one.

    Each epoch runs ``train_step`` then ``test_step``. It then saves the model's
    state dict to ``checkpoint_dir/checkpoint_name`` and prints the metrics.

    Args:
        model: Network to train; moved to ``device`` before the first epoch.
        train_dataloader: Batches used for optimisation.
        test_dataloader: Batches used for evaluation after every epoch.
        optimizer: Optimizer passed to ``train_step``.
        loss_fn: Criterion used for both training and evaluation.
        epochs: Number of epochs to run.
        device: Device the model and batches are moved to.
        checkpoint_dir: Directory the checkpoint is written to.
        checkpoint_name: Checkpoint file name (must end with ``.pt``/``.pth``).

    Returns:
        Dict with keys ``train_loss``, ``train_accuracy``, ``test_loss`` and
        ``test_accuracy``, each holding one value per epoch.
    """
    results: Dict[str, List[float]] = {
        "train_loss": [],
        "train_accuracy": [],
        "test_loss": [],
        "test_accuracy": [],
    }

    model.to(device)

    for epoch in tqdm(range(epochs)):
        train_loss, train_accuracy = train_step(model=model,
                                                dataloader=train_dataloader,
                                                loss_fn=loss_fn,
                                                optimizer=optimizer,
                                                device=device)
        test_loss, test_accuracy = test_step(model=model,
                                             dataloader=test_dataloader,
                                             loss_fn=loss_fn,
                                             device=device)

        # The same file is overwritten every epoch. After training it holds
        # the final epoch's weights, not necessarily those with the best test
        # accuracy.
        save_model(model, checkpoint_dir, checkpoint_name)

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_accuracy:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_accuracy:.4f} | "
        )

        results["train_loss"].append(train_loss)
        results["train_accuracy"].append(train_accuracy)
        results["test_loss"].append(test_loss)
        results["test_accuracy"].append(test_accuracy)

    return results


# 4. Main file

In [None]:
from torchvision import transforms
import os
import torch
import torchvision
import torchinfo

from timeit import default_timer as timer

NUM_EPOCHS = 12
BATCH_SIZE = 28
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-5
IMAGE_SIZE = (384, 384)  # input resolution shared by both transform pipelines
# Standard ImageNet per-channel statistics, used by both pipelines.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

torch.manual_seed(42)
torch.cuda.manual_seed(42)

train_path = "transformed_dataset/train"
test_path = "transformed_dataset/test"

device = "cuda" if torch.cuda.is_available() else "cpu"

if __name__ == "__main__":

    # Every image is resized to IMAGE_SIZE first, so a RandomCrop of the same
    # size would always return the whole image; none is used here.
    train_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.4),
        # With probability 0.4, apply both augmentations below, in order.
        transforms.RandomApply(transforms=[
            transforms.TrivialAugmentWide(),
            transforms.RandomRotation(degrees=(0, 25)),
        ], p=0.4),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    test_transform = transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ])

    train_dataloader, test_dataloader, class_names = create_dataloaders(
        train_path=train_path,
        test_path=test_path,
        train_transform=train_transform,
        test_transform=test_transform,
        batch_size=BATCH_SIZE,
    )

    weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
    model = torchvision.models.efficientnet_v2_s(weights=weights).to(device)

    output_shape = len(class_names)

    # Replace the pretrained head with one sized for this dataset. The feature
    # width is read from the existing final Linear layer rather than hard-coded.
    in_features = model.classifier[-1].in_features
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=0.5, inplace=True),
        torch.nn.Linear(in_features=in_features,
                        out_features=output_shape,
                        bias=True)).to(device)

    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Optional architecture summary (needs torchinfo):
    # print(torchinfo.summary(model=model,
    #                         input_size=(BATCH_SIZE, 3, *IMAGE_SIZE),
    #                         verbose=0,
    #                         col_names=["input_size", "output_size",
    #                                    "num_params", "trainable"],
    #                         col_width=20,
    #                         row_settings=["var_names"]
    #                         ))

    start_time = timer()

    # Per-epoch metrics are kept for later inspection/plotting.
    results = train(model=model,
                    train_dataloader=train_dataloader,
                    test_dataloader=test_dataloader,
                    loss_fn=loss_fn,
                    optimizer=optimizer,
                    epochs=NUM_EPOCHS,
                    device=device)

    end_time = timer()

    print(f"Total training time: {end_time-start_time:.3f} seconds")

In [None]:
from google.colab import files

# Download the checkpoint file that the training loop writes to "models/"
# under the name "efficientnet_v2_s_384_none.pth".
files.download("models/efficientnet_v2_s_384_none.pth")