In [3]:
import copy
import numpy as np
import random
import timm
import torch
import torch.nn as nn

from PIL import Image
from pathlib import Path
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

In [62]:
MANUAL_SEED = 27
NUMBER_OF_CLASSES = 4
TRAIN_VERSION = 5

RAW_DATA_DIR = "../training_data"

NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NOMRALIZE_STD = [0.229, 0.224, 0.225]

In [63]:
transform_self_base = transforms.Compose([
    transforms.Resize((224, 224)),
])

transform_self_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(15),
    transforms.ColorJitter(
        brightness=0.3,
        contrast=0.3,
        saturation=0.3,
        hue=0.05
    ),
    transforms.RandomResizedCrop((224, 224), scale=(0.8, 1.0)),
])

In [64]:
base_dataset = datasets.ImageFolder(root=RAW_DATA_DIR)

train_size = int(0.8*len(base_dataset))
validation_size = len(base_dataset)-train_size

generator = torch.Generator().manual_seed(MANUAL_SEED)
train_dataset, validation_dataset = random_split(base_dataset, [train_size, validation_size], generator=generator)

In [65]:
def load_class_dataset(target_class, num_variants):
    class_index = train_dataset.dataset.class_to_idx[target_class]
    indices = [i for i, (_, label) in enumerate(train_dataset.dataset.samples) if label == class_index]
    train_subset = Subset(train_dataset.dataset, indices)

    index_to_class = {v: k for k, v in train_dataset.dataset.class_to_idx.items()}

    class_images = []
    for i in range(len(train_subset)):
        image, label = train_subset[i]
        
        base_image = transform_self_base(image)
        class_images.append((base_image, label))

        class_name = index_to_class[label]
        if class_name == "sigatoka":
            continue

        for _ in range(num_variants):
            augmented_image = transform_self_aug(image)
            class_images.append((augmented_image, label))

    return class_images

In [66]:
classes = ["sigatoka", "cordana", "healthy", "pestalotiopsis"]
augemented_train_dataset = []
for class_name in classes:
    class_dataset = load_class_dataset(class_name, 4)
    augemented_train_dataset += class_dataset

In [67]:
weak_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=NORMALIZE_MEAN, 
        std=NOMRALIZE_STD
    ),
])

strong_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.ColorJitter(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5),
    transforms.Normalize(
        mean=NORMALIZE_MEAN, 
        std=NOMRALIZE_STD
    ),
])

validation_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=NORMALIZE_MEAN, 
        std=NOMRALIZE_STD
    ),
])

In [68]:
class DualTransformDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, weak_transform, strong_transform):
        self.dataset = dataset
        self.weak_transform = weak_transform
        self.strong_transform = strong_transform

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        return self.weak_transform(image), self.strong_transform(image), label

    def __len__(self):
        return len(self.dataset)

augemented_train_dataset = DualTransformDataset(augemented_train_dataset, weak_transform, strong_transform)
validation_dataset = [(validation_transform(image), label) for image, label in validation_dataset]

train_loader = DataLoader(augemented_train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32, shuffle=False)

In [69]:
random.seed(MANUAL_SEED)
np.random.seed(MANUAL_SEED)
torch.manual_seed(MANUAL_SEED)
torch.cuda.manual_seed_all(MANUAL_SEED)

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [70]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

targets = [label for _, _, label in augemented_train_dataset]
class_weights = compute_class_weight("balanced", classes=np.unique(targets), y=targets)
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=class_weights)

In [71]:
student = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=NUMBER_OF_CLASSES)
teacher = copy.deepcopy(student)

student = student.to(device)
teacher = teacher.to(device)

In [72]:
def update_teacher(student, teacher, ema_decay=0.99):
    for student_param, teacher_param in zip(student.parameters(), teacher.parameters()):
        teacher_param.data = ema_decay*teacher_param.data + (1-ema_decay)*student_param.data

In [74]:
optimizer = torch.optim.Adam(student.parameters(), lr=3e-4, weight_decay=1e-4)
lambda_weight = 0.5
num_epochs = 30
best_val_acc = 0.0

for epoch in range(num_epochs):
    student.train()
    teacher.eval()
    running_loss = 0.0
    total, correct = 0, 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for weak_img, strong_img, labels in loop:
        weak_img, strong_img, labels = weak_img.to(device), strong_img.to(device), labels.to(device)

        with torch.no_grad():
            teacher_output = teacher(weak_img)

        student_output = student(strong_img)
        loss_ce = criterion(student_output, labels)
        loss_consistency = mse_loss(student_output, teacher_output.detach())
        loss = loss_ce + lambda_weight * loss_consistency

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_teacher(student, teacher)

        running_loss += loss.item()
        _, predicted = torch.max(student_output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        loop.set_postfix(loss=running_loss / (total // 32), accuracy=100. * correct / total)
        
    student.eval()
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for image, labels in validation_loader:
            image, labels = image.to(device), labels.to(device)
            outputs = student(image)
            loss = criterion(outputs, labels)

            val_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    avg_val_loss = val_loss / len(validation_loader)
    val_accuracy = 100. * val_correct / val_total

    f1 = f1_score(true_labels, predicted_labels, average='macro')

    print(f"Validation Loss: {avg_val_loss:.4f} | Accuracy: {val_accuracy:.2f}% | F1 Score (macro): {f1:.4f}")
    
    log_path = Path("../models/model1") / f"version_{TRAIN_VERSION}" / "training_log.txt"
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with open(log_path, "a") as f:
        f.write(f"Epoch {epoch+1}: Val Loss = {avg_val_loss:.4f}, Accuracy = {val_accuracy:.2f}%, F1 = {f1:.4f}\n")

    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy

        save_path = Path("../models/model1/") / f"version_{TRAIN_VERSION}"
        save_path.mkdir(parents=True, exist_ok=True)

        torch.save(student.state_dict(), save_path / "student.pth")
        torch.save(teacher.state_dict(), save_path / "teacher.pth")
        print(f"💾 Saved new best model at Epoch {epoch+1}")

Epoch 1/30: 100%|██████████| 79/79 [17:40<00:00, 13.43s/it, accuracy=47, loss=1.44]  


Validation Loss: 0.8030 | Accuracy: 71.60% | F1 Score (macro): 0.6281
💾 Saved new best model at Epoch 1


Epoch 2/30: 100%|██████████| 79/79 [18:56<00:00, 14.38s/it, accuracy=81.2, loss=0.93] 


Validation Loss: 0.5994 | Accuracy: 89.94% | F1 Score (macro): 0.8818
💾 Saved new best model at Epoch 2


Epoch 3/30: 100%|██████████| 79/79 [19:58<00:00, 15.17s/it, accuracy=89, loss=0.767]  


Validation Loss: 0.5489 | Accuracy: 89.94% | F1 Score (macro): 0.8820


Epoch 4/30: 100%|██████████| 79/79 [15:27<00:00, 11.73s/it, accuracy=90.8, loss=0.675]


Validation Loss: 0.5709 | Accuracy: 94.08% | F1 Score (macro): 0.9453
💾 Saved new best model at Epoch 4


Epoch 5/30: 100%|██████████| 79/79 [16:01<00:00, 12.17s/it, accuracy=90.9, loss=0.672]


Validation Loss: 0.4178 | Accuracy: 94.67% | F1 Score (macro): 0.9445
💾 Saved new best model at Epoch 5


Epoch 6/30: 100%|██████████| 79/79 [14:11<00:00, 10.77s/it, accuracy=92.6, loss=0.627]


Validation Loss: 0.4846 | Accuracy: 95.86% | F1 Score (macro): 0.9587
💾 Saved new best model at Epoch 6


Epoch 7/30: 100%|██████████| 79/79 [14:13<00:00, 10.81s/it, accuracy=93.3, loss=0.623]


Validation Loss: 0.4156 | Accuracy: 91.72% | F1 Score (macro): 0.9108


Epoch 8/30: 100%|██████████| 79/79 [13:42<00:00, 10.42s/it, accuracy=94.4, loss=0.606]


Validation Loss: 0.3836 | Accuracy: 97.04% | F1 Score (macro): 0.9609
💾 Saved new best model at Epoch 8


Epoch 9/30: 100%|██████████| 79/79 [14:31<00:00, 11.03s/it, accuracy=93.5, loss=0.639]


Validation Loss: 0.4939 | Accuracy: 95.86% | F1 Score (macro): 0.9506


Epoch 10/30: 100%|██████████| 79/79 [16:31<00:00, 12.56s/it, accuracy=94.8, loss=0.584]


Validation Loss: 0.3523 | Accuracy: 97.04% | F1 Score (macro): 0.9608


Epoch 11/30: 100%|██████████| 79/79 [16:42<00:00, 12.69s/it, accuracy=95.5, loss=0.573]


Validation Loss: 0.4376 | Accuracy: 97.63% | F1 Score (macro): 0.9781
💾 Saved new best model at Epoch 11


Epoch 12/30: 100%|██████████| 79/79 [16:39<00:00, 12.65s/it, accuracy=95, loss=0.576]  


Validation Loss: 0.3580 | Accuracy: 97.63% | F1 Score (macro): 0.9704


Epoch 13/30: 100%|██████████| 79/79 [16:40<00:00, 12.67s/it, accuracy=96.1, loss=0.541]


Validation Loss: 0.3775 | Accuracy: 97.63% | F1 Score (macro): 0.9692


Epoch 14/30: 100%|██████████| 79/79 [4:11:00<00:00, 190.64s/it, accuracy=96.3, loss=0.543]    


Validation Loss: 0.3201 | Accuracy: 100.00% | F1 Score (macro): 1.0000
💾 Saved new best model at Epoch 14


Epoch 15/30: 100%|██████████| 79/79 [14:14<00:00, 10.82s/it, accuracy=95.7, loss=0.56] 


Validation Loss: 0.3374 | Accuracy: 98.22% | F1 Score (macro): 0.9798


Epoch 16/30: 100%|██████████| 79/79 [13:25<00:00, 10.20s/it, accuracy=96.8, loss=0.535]


Validation Loss: 0.3396 | Accuracy: 98.82% | F1 Score (macro): 0.9881


Epoch 17/30: 100%|██████████| 79/79 [13:52<00:00, 10.54s/it, accuracy=96.6, loss=0.532]


Validation Loss: 0.3529 | Accuracy: 99.41% | F1 Score (macro): 0.9943


Epoch 18/30: 100%|██████████| 79/79 [15:39<00:00, 11.89s/it, accuracy=95.4, loss=0.617]


Validation Loss: 0.3940 | Accuracy: 95.27% | F1 Score (macro): 0.9515


Epoch 19/30: 100%|██████████| 79/79 [14:45<00:00, 11.21s/it, accuracy=94.4, loss=0.609]


Validation Loss: 0.3328 | Accuracy: 100.00% | F1 Score (macro): 1.0000


Epoch 20/30: 100%|██████████| 79/79 [14:28<00:00, 11.00s/it, accuracy=96.9, loss=0.523]


Validation Loss: 0.3217 | Accuracy: 99.41% | F1 Score (macro): 0.9947


Epoch 21/30: 100%|██████████| 79/79 [14:22<00:00, 10.92s/it, accuracy=95.1, loss=0.604]


Validation Loss: 0.3216 | Accuracy: 99.41% | F1 Score (macro): 0.9947


Epoch 22/30: 100%|██████████| 79/79 [14:21<00:00, 10.91s/it, accuracy=96.7, loss=0.543]


Validation Loss: 0.3538 | Accuracy: 98.82% | F1 Score (macro): 0.9892


Epoch 23/30: 100%|██████████| 79/79 [16:13<00:00, 12.33s/it, accuracy=96.3, loss=0.546]


Validation Loss: 0.3428 | Accuracy: 98.22% | F1 Score (macro): 0.9840


Epoch 24/30: 100%|██████████| 79/79 [15:57<00:00, 12.12s/it, accuracy=97.1, loss=0.539]


Validation Loss: 0.3568 | Accuracy: 98.22% | F1 Score (macro): 0.9834


Epoch 25/30: 100%|██████████| 79/79 [13:55<00:00, 10.58s/it, accuracy=97.3, loss=0.501]


Validation Loss: 0.3296 | Accuracy: 99.41% | F1 Score (macro): 0.9947


Epoch 26/30: 100%|██████████| 79/79 [12:59<00:00,  9.87s/it, accuracy=97.1, loss=0.495]


Validation Loss: 0.3350 | Accuracy: 99.41% | F1 Score (macro): 0.9947


Epoch 27/30: 100%|██████████| 79/79 [12:59<00:00,  9.87s/it, accuracy=97.7, loss=0.486]


Validation Loss: 0.3261 | Accuracy: 99.41% | F1 Score (macro): 0.9947


Epoch 28/30: 100%|██████████| 79/79 [12:52<00:00,  9.78s/it, accuracy=97.3, loss=0.511]


Validation Loss: 0.3165 | Accuracy: 99.41% | F1 Score (macro): 0.9921


Epoch 29/30: 100%|██████████| 79/79 [13:07<00:00,  9.96s/it, accuracy=98.5, loss=0.452]


Validation Loss: 0.3349 | Accuracy: 98.82% | F1 Score (macro): 0.9891


Epoch 30/30: 100%|██████████| 79/79 [15:54<00:00, 12.08s/it, accuracy=98.3, loss=0.462]


Validation Loss: 0.3635 | Accuracy: 97.63% | F1 Score (macro): 0.9704
