# Machine Learning - Image Classification

In [2]:
#importing
import os

import torch
import torch.nn as nn

from dotenv import load_dotenv
from jupyter_client.consoleapp import classes
from torch.utils.data import DataLoader

## Database importing

In [None]:
# TODO(martin): use this.
# Read variables from a local .env file into the process environment, then
# look up where the training data lives. DAT_PATH is None when the
# TRAINING_DATA_PATH variable is not set.
load_dotenv()
DAT_PATH = os.getenv("TRAINING_DATA_PATH")

# Placeholder: the original assignment had no right-hand side (a syntax
# error). Replace None with the actual dataset loading once it is written.
ANIMALS_DATAFRAME = None

### Batch Preparation

## Model implementation

In [3]:
import torch.nn.functional as fun

In [4]:
#PRE TUTO FUNKCIU NEGENERUJ ZIADNE KOMENTARE
class ImageClassifier(nn.Module):
    """Small three-block CNN that classifies square RGB images.

    Each block is ``Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU``
    followed by a 2x2 max-pool, so the spatial size is halved three times
    (divided by 8 overall) while the channel count grows 3 -> 16 -> 32 -> 64.
    The flattened feature map is then passed through two linear layers.

    Args:
        classes: Number of output classes (size of the logits vector).
        input_size: Side length, in pixels, of the square input images. The
            default of 256 reproduces the original fixed
            ``64 * 32 * 32`` input width of the first linear layer.

    Raises:
        ValueError: If ``input_size`` is smaller than 8, because three 2x2
            poolings would leave no spatial positions.
    """

    def __init__(self, classes: int, input_size: int = 256):
        super().__init__()
        if input_size < 8:
            raise ValueError(
                f"input_size must be at least 8, got {input_size}"
            )
        self.numberOfClasses = classes

        # in_channels=3 means only RGB input is supported. A 3x3 kernel with
        # stride 1 and padding 1 keeps the spatial size unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(16)

        # Significant for Grad-CAM.
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(32)

        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)

        # Adaptive pooling (nn.AdaptiveAvgPool2d((4, 4))) would make the model
        # input-size agnostic; it is intentionally not used for now.

        # The first linear layer sees channels * height * width of the last
        # feature map. Three 2x2 max-pools floor-divide the side by 8
        # (256 -> 32 for the default input size).
        pooled_side = input_size // 8
        self.fc1 = nn.Linear(in_features=64 * pooled_side * pooled_side, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=self.numberOfClasses)

    def forward(self, x):
        """Return unnormalized class scores (logits) for a batch of images.

        Args:
            x: Image batch with 3 channels; after the three poolings its
                flattened size must match ``fc1.in_features``.

        Returns:
            Tensor with one row per input image and ``numberOfClasses``
            columns.
        """
        # Block 1
        x = fun.relu(self.bn1(self.conv1(x)))
        x = fun.max_pool2d(x, kernel_size=2)  # halves the spatial size

        # Block 2
        x = fun.relu(self.bn2(self.conv2(x)))
        x = fun.max_pool2d(x, kernel_size=2)

        # Block 3
        x = fun.relu(self.bn3(self.conv3(x)))
        x = fun.max_pool2d(x, kernel_size=2)

        # Flatten everything except the batch dimension.
        x = x.flatten(1)

        # The original called self.fc1() without its input, which raised
        # TypeError on every forward pass.
        x = fun.relu(self.fc1(x))
        x = self.fc2(x)
        return x


### Grad-CAM Implementation
On hold. Reference:
https://medium.com/@codetrade/grad-cam-in-pytorch-a-powerful-tool-for-visualize-explanations-from-deep-networks-bdc7caf0b282



## Model Training

### Training function

In [None]:
def _run_epoch(model, batches, criterion, device, optimizer=None):
    """Run one pass over ``batches`` and return ``(mean_loss, accuracy)``.

    The pass is a training pass (gradients enabled, weights updated) when an
    ``optimizer`` is given, and an evaluation pass otherwise.

    Raises:
        ValueError: If ``batches`` yields no samples, since the averages
            would be undefined.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    hits = 0
    seen = 0

    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images = images.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.long)

            if training:
                optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch loss by batch size so the epoch average is
            # per-sample even when the last batch is smaller.
            loss_sum += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            hits += (preds == labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        raise ValueError("data loader yielded no samples")

    return loss_sum / seen, hits / seen


def train_model(model: nn.Module,
                train_loader: DataLoader,
                val_loader: DataLoader,
                criterion: nn.Module,
                optimizer: torch.optim.Optimizer,
                device: torch.device,
                epochs: int = 10,
                scheduler=None):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Yields ``(images, labels)`` batches used for training.
        val_loader: Yields ``(images, labels)`` batches used for validation.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates the model parameters.
        device: Device on which batches and the model are placed.
        epochs: Number of passes over ``train_loader``.
        scheduler: Optional learning-rate scheduler, stepped once per epoch.
            ``ReduceLROnPlateau`` receives the validation loss; any other
            scheduler is stepped without arguments.

    Returns:
        Dict with the keys ``"train_loss"``, ``"train_acc"``, ``"val_loss"``
        and ``"val_acc"``, each a list holding one value per epoch.

    Raises:
        ValueError: If either loader yields no samples.
    """
    # tqdm is only used for the progress bar; fall back to plain iteration
    # when it is not installed so training still runs.
    try:
        from tqdm import tqdm
    except ImportError:
        def tqdm(iterable, **_kwargs):
            return iterable

    model.to(device)

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(1, epochs + 1):
        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs} [Train]")
        train_loss, train_acc = _run_epoch(
            model, progress, criterion, device, optimizer=optimizer
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        if scheduler is not None:
            # Only ReduceLROnPlateau takes a metric; passing the loss to other
            # schedulers would be misread as an epoch index.
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    return history

### Training

### Hyperparameter sweeping