In [None]:
import os
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from sklearn.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from PIL import Image
import copy
import time
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import pandas as pd
import numpy as np
from models.model import Student, Teacher

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
class RiceSeedDataset(Dataset):
    """Map-style dataset that loads the images listed in a metadata table.

    Args:
        data: pandas DataFrame with a "Path" column (image file location) and
            a "Label" column (value converted to a long tensor).
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(image, label)`` for the row at position ``index``."""
        # Positional lookup: a sampler yields 0..len-1, which only coincides
        # with the frame's index labels when the frame was reset beforehand.
        # Using .iloc keeps the dataset correct for any DataFrame index.
        label = torch.tensor(self.data["Label"].iloc[index]).long()
        path = self.data["Path"].iloc[index]

        img = Image.open(path).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        """Number of rows (samples) in the metadata table."""
        return len(self.data)

# Per-channel mean/std handed to Normalize (the values match the widely used
# ImageNet statistics).
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

# One preprocessing pipeline per split. Every split resizes to 224x224 with
# Lanczos interpolation, converts to a tensor and normalises; only "Train"
# additionally applies random flips and a small random rotation.
transform = {}
for _split in ("Train", "Validation", "Test"):
    _steps = [
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
    ]
    if _split == "Train":
        _steps += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(10),
        ]
    _steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    transform[_split] = transforms.Compose(_steps)

In [None]:
# Metadata table for all rice varieties. Later cells read its "RiceSeed",
# "Type", "Label" and "Path" columns.
df = pd.read_csv(filepath_or_buffer="/kaggle/input/meta-rice-co-hong/meta_rice_seed.csv")
df

In [None]:
# Variety to train on. Candidate values:
# 'BC15', 'HuongThom', 'Nep87', 'Q5', 'TBR36', 'TBR45', 'TH3', 'ThienUu8', 'Xi23'
name = "TH3"

# Keep only the rows of the chosen variety, then split them by the "Type"
# column. Every frame gets a fresh 0..n-1 index.
sub_df = df.loc[df["RiceSeed"] == name].copy().reset_index(drop=True)
df_train, df_val, df_test = (
    sub_df.loc[sub_df["Type"] == split_name].copy().reset_index(drop=True)
    for split_name in ("Train", "Validation", "Test")
)

df_train.shape, df_val.shape, df_test.shape

In [None]:
# Directory that receives the reports, figures, histories and weights written
# by the later cells. An empty string means "the current working directory"
# (os.path.join('', name) == name).
folder_path = ''

# os.makedirs('') raises FileNotFoundError even with exist_ok=True, so only
# create the directory when a non-empty path has been configured.
if folder_path:
    os.makedirs(folder_path, exist_ok=True)

In [None]:
# ---- Hyper-parameters -------------------------------------------------------
BATCH_SIZE = 64
LEARNING_RATE = 5e-6
EPOCHS = 100
TEMPERATURE1 = 2  # softmax temperature read by stage_wise_response_distillation
TEMPERATURE2 = 2  # softmax temperature read by cross_stage_review_distillation
ALPHA = 0.9       # train_model: weight of the student CE loss; (1 - ALPHA) weights L_SRD
GAMMA = 0.05      # train_model: weight of the cross-stage review loss (L_CRD)
BETA = 0.05       # train_model: weight of the channel distillation loss (L_SCD)

# Run summary; later cells add best epochs, timings and test metrics to it.
result = dict(
    Optimizer="Adam",
    Learning_rate=LEARNING_RATE,
    Num_Epochs=EPOCHS,
    TEMPERATURE1=TEMPERATURE1,
    TEMPERATURE2=TEMPERATURE2,
    Alpha=ALPHA,
    Gamma=GAMMA,
    Beta=BETA,
)

# Distinct label values present in the training split.
classes = df_train["Label"].unique()

# ---- Datasets and loaders ---------------------------------------------------
train_dataset, valid_dataset, test_dataset = (
    RiceSeedDataset(data=split_frame, transform=transform[split_name])
    for split_frame, split_name in (
        (df_train, "Train"),
        (df_val, "Validation"),
        (df_test, "Test"),
    )
)

# Only the training loader shuffles.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
valid_dataloader, test_dataloader = (
    DataLoader(dataset=eval_dataset, batch_size=BATCH_SIZE, shuffle=False)
    for eval_dataset in (valid_dataset, test_dataset)
)

# Phase name -> loader mapping consumed by train_model.
data_loader = dict(Train=train_dataloader, Validation=valid_dataloader)

print(classes)

In [None]:
def cross_entropy_loss(logits, labels, criterion):
    """Evaluate ``criterion`` and the arg-max prediction for every branch.

    Args:
        logits: iterable of logit tensors, one per branch; classes on dim 1.
        labels: target tensor passed unchanged to ``criterion``.
        criterion: callable ``criterion(logit, labels)`` returning the loss.

    Returns:
        ``(losses, predictions)``: two lists aligned with ``logits`` holding the
        per-branch loss and the per-sample index of the largest logit.
    """
    per_branch = [
        (criterion(branch_logits, labels), torch.max(branch_logits, dim=1)[1])
        for branch_logits in logits
    ]
    losses = [branch_loss for branch_loss, _ in per_branch]
    predictions = [branch_pred for _, branch_pred in per_branch]
    return losses, predictions

def stage_wise_response_distillation(logit_students, logit_teacher, criterion):
    """Sum the temperature-softened distillation loss over paired branches.

    Student and teacher logits are divided by the module-level ``TEMPERATURE1``;
    the student side goes through ``log_softmax`` and the teacher side through
    ``softmax`` (both over dim 1) before being handed to ``criterion``. Each
    term is scaled by ``TEMPERATURE1 ** 2``.

    Args:
        logit_students: iterable of student logit tensors, one per branch.
        logit_teacher: iterable of teacher logit tensors, paired by position.
        criterion: callable ``criterion(log_probs_student, probs_teacher)``.

    Returns:
        The summed loss; ``0.0`` when there are no pairs.
    """

    def _softened_term(student_logits, teacher_logits):
        log_probs_student = F.log_softmax(student_logits / TEMPERATURE1, dim=1)
        probs_teacher = F.softmax(teacher_logits / TEMPERATURE1, dim=1)
        return criterion(log_probs_student, probs_teacher) * (TEMPERATURE1 ** 2)

    return sum(
        (_softened_term(s, t) for s, t in zip(logit_students, logit_teacher)),
        0.0,
    )

def stage_wise_channel_distillation(fea_students, fea_teachers):
    """Sum the mean L2 distance between paired student/teacher feature maps.

    Each feature tensor is flattened to ``(size(0), -1)``, the per-row
    Euclidean distance between student and teacher is computed with
    ``F.pairwise_distance`` and averaged over rows; the averages are summed
    over all pairs.

    Args:
        fea_students: iterable of student feature tensors, one per stage.
        fea_teachers: iterable of teacher feature tensors, paired by position.
            Each pair must flatten to the same number of columns.

    Returns:
        The summed distance; ``0.0`` when there are no pairs.
    """
    total_dist = 0.0
    for fea_s, fea_t in zip(fea_students, fea_teachers):
        # reshape (not view): view raises on non-contiguous tensors such as
        # transposed or channels_last feature maps, reshape copies if needed.
        flat_s = fea_s.reshape(fea_s.size(0), -1)
        flat_t = fea_t.reshape(fea_t.size(0), -1)
        dist = F.pairwise_distance(flat_s, flat_t, p=2)
        total_dist = total_dist + dist.mean()
    return total_dist

def cross_stage_review_distillation(logits1_student, logits2_student, criterion):
    """Distil the student's first branch into its second branch.

    The branch-1 logits are detached, so they act as a fixed target and receive
    no gradient from this loss. Both sets of logits are divided by the
    module-level ``TEMPERATURE2``; branch 2 goes through ``log_softmax`` and
    branch 1 through ``softmax`` (dim 1). The criterion value is scaled by
    ``TEMPERATURE2 ** 2``.

    Args:
        logits1_student: logits of the student's first branch (target side).
        logits2_student: logits of the student's second branch (learner side).
        criterion: callable ``criterion(log_probs, target_probs)``.

    Returns:
        The scaled loss for the single branch-1 -> branch-2 pairing.
    """
    temperature = TEMPERATURE2
    target_logits = logits1_student.detach()

    learner_log_probs = F.log_softmax(logits2_student / temperature, dim=1)
    target_probs = F.softmax(target_logits / temperature, dim=1)

    return criterion(learner_log_probs, target_probs) * (temperature ** 2)

def train_model(data_loader, model, criterion, optimizer, num_epochs, device, early_stop=True, patience=10):
    """Train a teacher and a student side by side, distilling teacher into student.

    For every batch the teacher is evaluated first; in the "Train" phase it is
    updated with the sum of its two branch cross-entropy losses. The student is
    then evaluated on the same batch and, in the "Train" phase, updated with

        BETA * L_SCD + GAMMA * L_CRD + (1 - ALPHA) * L_SRD + ALPHA * L_CE

    where ALPHA, GAMMA and BETA are module-level constants and the teacher
    features/logits are detached before they are used as distillation targets.

    Args:
        data_loader: mapping with "Train" and "Validation" entries, each an
            iterable yielding ``(images, labels)`` tensor batches.
        model: ``(student, teacher)`` pair. Each network is called on a batch
            and must return ``(fea1, logit1, fea2, logit2)``.
        criterion: ``(cross_entropy_criterion, kl_criterion)`` pair.
        optimizer: ``(optimizer_student, optimizer_teacher)`` pair.
        num_epochs: maximum number of epochs to run.
        device: device the networks and batches are moved to.
        early_stop: when true, stop once the student's validation loss has not
            improved for ``patience`` consecutive epochs.
        patience: number of non-improving epochs tolerated before stopping.

    Returns:
        A 5-tuple ``(model, history, time_elapsed, best_val_loss, best_epoch)``:
        ``model`` is ``(student, teacher)`` reloaded with their best weights;
        ``history`` is ``(student_df, teacher_df)`` with one row per epoch;
        ``time_elapsed`` is the wall-clock duration in seconds;
        ``best_val_loss`` and ``best_epoch`` are ``(student, teacher)`` pairs,
        epochs being 1-based. The student is selected on the validation
        cross-entropy of its last branch, the teacher on the mean validation
        cross-entropy of its two branches.
    """
    student, teacher = model
    student, teacher = student.to(device), teacher.to(device)

    criterion_ce, criterion_kl = criterion
    optimizer_student, optimizer_teacher = optimizer

    wait = 0  # epochs since the student's validation loss last improved
    best_teacher_wts = copy.deepcopy(teacher.state_dict())
    best_student_wts = copy.deepcopy(student.state_dict())
    best_val_loss_teacher, best_val_loss_student = float("inf"), float("inf")
    best_epoch_teacher, best_epoch_student = 0, 0

    history_teacher = {
        "Train_Loss1": [], "Train_Acc1": [], "Validation_Loss1": [], "Validation_Acc1": [],
        "Train_Loss2": [], "Train_Acc2": [], "Validation_Loss2": [], "Validation_Acc2": [],
        "Time": []
    }

    history_student = {
        "Train_Loss1": [], "Train_Acc1": [], "Validation_Loss1": [], "Validation_Acc1": [],
        "Train_Loss2": [], "Train_Acc2": [], "Validation_Loss2": [], "Validation_Acc2": [],
        "Time": []
    }

    since = time.time()

    for epoch in range(1, num_epochs + 1):
        print("-----------------------------------------------------------------------")
        print(f"Epoch {epoch}/{num_epochs}")
        epoch_start = time.time()

        for phase in ["Train", "Validation"]:
            is_train = phase == "Train"
            # train(False) is equivalent to eval().
            teacher.train(is_train)
            student.train(is_train)

            # Per-branch sums of (batch loss * batch size) and of correct predictions.
            running_loss_teacher = [0.0] * 2
            running_corrects_teacher = [0] * 2
            running_loss_student = [0.0] * 2
            running_corrects_student = [0] * 2
            total_samples = 0

            for images, labels in tqdm(data_loader[phase], desc=f"{phase}"):
                images, labels = images.to(device), labels.to(device)
                batch_size = images.size(0)
                total_samples += batch_size

                # Gradients are tracked only while training.
                with torch.set_grad_enabled(is_train):
                    # ---- Teacher: plain supervised step on both branches ----
                    fea1_t, logit1_t, fea2_t, logit2_t = teacher(images)

                    losses_t, predictions_t = cross_entropy_loss(logits=[logit1_t, logit2_t],
                                                                 labels=labels,
                                                                 criterion=criterion_ce)
                    total_loss_teacher = sum(losses_t)

                    if is_train:
                        optimizer_teacher.zero_grad()
                        total_loss_teacher.backward()
                        optimizer_teacher.step()

                    # ---- Student: supervised loss plus three distillation terms ----
                    fea1_s, logit1_s, fea2_s, logit2_s = student(images)

                    losses_s, predictions_s = cross_entropy_loss(logits=[logit1_s, logit2_s],
                                                                 labels=labels,
                                                                 criterion=criterion_ce)
                    L_CE = sum(losses_s)

                    # Teacher outputs are detached so the student loss never
                    # back-propagates into the teacher.
                    L_SCD = stage_wise_channel_distillation(fea_students=[fea1_s, fea2_s],
                                                            fea_teachers=[fea1_t.detach(), fea2_t.detach()])

                    L_SRD = stage_wise_response_distillation(logit_students=[logit1_s, logit2_s],
                                                             logit_teacher=[logit1_t.detach(), logit2_t.detach()],
                                                             criterion=criterion_kl)

                    L_CRD = cross_stage_review_distillation(logits1_student=logit1_s,
                                                            logits2_student=logit2_s,
                                                            criterion=criterion_kl)

                    total_loss_student = (BETA * L_SCD) + (GAMMA * L_CRD) + ((1 - ALPHA) * L_SRD) + (ALPHA * L_CE)

                    if is_train:
                        optimizer_student.zero_grad()
                        total_loss_student.backward()
                        optimizer_student.step()

                # Only the cross-entropy terms are logged (for both networks).
                for i in range(2):
                    running_loss_teacher[i] += losses_t[i].item() * batch_size
                    running_corrects_teacher[i] += (predictions_t[i] == labels).sum().item()
                    running_loss_student[i] += losses_s[i].item() * batch_size
                    running_corrects_student[i] += (predictions_s[i] == labels).sum().item()

            for i in range(2):
                history_teacher[f"{phase}_Loss{i+1}"].append(running_loss_teacher[i] / total_samples)
                history_teacher[f"{phase}_Acc{i+1}"].append(running_corrects_teacher[i] / total_samples)

                history_student[f"{phase}_Loss{i+1}"].append(running_loss_student[i] / total_samples)
                history_student[f"{phase}_Acc{i+1}"].append(running_corrects_student[i] / total_samples)

            if phase == "Validation":
                # Student checkpointing (and early stopping) follows the last branch.
                final_branch_val_loss = history_student["Validation_Loss2"][-1]

                if best_val_loss_student > final_branch_val_loss:
                    best_val_loss_student = final_branch_val_loss
                    best_student_wts = copy.deepcopy(student.state_dict())
                    best_epoch_student = epoch
                    wait = 0
                else:
                    wait += 1

                # Teacher checkpointing follows the mean of both branches.
                avg_val_loss = sum(history_teacher[f"Validation_Loss{i+1}"][-1] for i in range(2)) / 2

                if best_val_loss_teacher > avg_val_loss:
                    best_val_loss_teacher = avg_val_loss
                    best_teacher_wts = copy.deepcopy(teacher.state_dict())
                    best_epoch_teacher = epoch

        epoch_duration = time.time() - epoch_start
        history_teacher["Time"].append(epoch_duration)
        history_student["Time"].append(epoch_duration)

        print("Teacher")
        for i in range(2):
            print(f"Branch {i+1} - Train Loss: {history_teacher[f'Train_Loss{i+1}'][-1]:.4f}, "
                  f"Train Acc: {history_teacher[f'Train_Acc{i+1}'][-1]:.4f} | "
                  f"Valid Loss: {history_teacher[f'Validation_Loss{i+1}'][-1]:.4f}, "
                  f"Valid Acc: {history_teacher[f'Validation_Acc{i+1}'][-1]:.4f}")

        print("Student")
        for i in range(2):
            print(f"Branch {i+1} - Train Loss: {history_student[f'Train_Loss{i+1}'][-1]:.4f}, "
                  f"Train Acc: {history_student[f'Train_Acc{i+1}'][-1]:.4f} | "
                  f"Valid Loss: {history_student[f'Validation_Loss{i+1}'][-1]:.4f}, "
                  f"Valid Acc: {history_student[f'Validation_Acc{i+1}'][-1]:.4f}")

        # `epoch` is already 1-based, so it is reported as-is.
        print(f"Epoch {epoch} finished in {epoch_duration:.2f}s")

        if early_stop and wait >= patience:
            print(f"Early stopping at epoch {epoch} (no improvement in {patience} epochs).")
            break

    print("-----------------------------------------------------------------------")
    time_elapsed = time.time() - since
    print(f"Training complete in {time_elapsed:.2f}s")

    # Hand back the best checkpoints rather than the last-epoch weights.
    student.load_state_dict(best_student_wts)
    teacher.load_state_dict(best_teacher_wts)
    model = (student, teacher)
    history = (pd.DataFrame(history_student), pd.DataFrame(history_teacher))
    best_val_loss = (best_val_loss_student, best_val_loss_teacher)
    best_epoch = (best_epoch_student, best_epoch_teacher)

    return model, history, time_elapsed, best_val_loss, best_epoch

In [None]:
# Build both networks and warm-start them from previously saved weights.
# map_location="cpu" lets checkpoints that were saved from a GPU load on a
# CPU-only machine as well; train_model moves the networks to `device` itself.
student = Student(num_classes=len(classes), num_layers=[4, 6, 8, 10], growth_rate=16)
student.load_state_dict(torch.load("student_weights.pth", map_location="cpu"))

teacher = Teacher(len(classes))
teacher.load_state_dict(torch.load("teacher_weights.pth", map_location="cpu"))

# Hard-label loss, and the KL criterion used by the distillation terms
# (which feed it log-probabilities and probabilities).
criterion_ce = nn.CrossEntropyLoss()
criterion_kl = nn.KLDivLoss(reduction="batchmean")

# One Adam optimizer per network, both with the same learning rate.
optimizer_student = torch.optim.Adam(student.parameters(), lr=LEARNING_RATE)
optimizer_teacher = torch.optim.Adam(teacher.parameters(), lr=LEARNING_RATE)

In [None]:
# Run the joint teacher/student training loop (early stopping after 20 epochs
# without improvement of the student's validation loss).
model, history, time_elapse, best_val_loss, best_epoch = train_model(
    data_loader=data_loader,
    model=(student, teacher),
    criterion=(criterion_ce, criterion_kl),
    optimizer=(optimizer_student, optimizer_teacher),
    num_epochs=EPOCHS,
    device=device,
    early_stop=True,
    patience=20,
)

# train_model returns every pair in (student, teacher) order.
student, teacher = model
history_student, history_teacher = history

# Record the best checkpoints and the total training time in the run summary.
result.update({
    "Best_val_loss_student": best_val_loss[0],
    "Best_val_loss_teacher": best_val_loss[1],
    "Best_epoch_student": best_epoch[0],
    "Best_epoch_teacher": best_epoch[1],
    "Time_elapse": time_elapse,
})

In [None]:
# Teacher: one row of plots per branch -- accuracy on the left, loss on the right.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))

for row, branch_id in enumerate((1, 2)):
    for col, (key, metric) in enumerate((("Acc", "Accuracy"), ("Loss", "Loss"))):
        ax = axs[row, col]
        ax.plot(history_teacher[f"Train_{key}{branch_id}"], label=f"Train {metric}")
        ax.plot(history_teacher[f"Validation_{key}{branch_id}"], label=f"Validation {metric}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.set_title(f"Branch {branch_id} - {metric}")
        ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Student: one row of plots per branch -- accuracy on the left, loss on the right.
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(15, 10))

for row, branch_id in enumerate((1, 2)):
    for col, (key, metric) in enumerate((("Acc", "Accuracy"), ("Loss", "Loss"))):
        ax = axs[row, col]
        ax.plot(history_student[f"Train_{key}{branch_id}"], label=f"Train {metric}")
        ax.plot(history_student[f"Validation_{key}{branch_id}"], label=f"Validation {metric}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(metric)
        ax.set_title(f"Branch {branch_id} - {metric}")
        ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# teacher: evaluate both branches on the test split
teacher.eval()

# Ground truth and predictions collected per branch, keyed by '1' / '2'.
branch_outputs = {str(branch): {'correct': [], 'predict': []} for branch in (1, 2)}

with torch.no_grad():
    for images, targets in test_dataloader:
        images, targets = images.to(device), targets.to(device)

        fea1, logit1, fea2, logit2 = teacher(images)

        for key, outputs in (('1', logit1), ('2', logit2)):
            predicts = torch.max(outputs, dim=1)[1]
            branch_outputs[key]['correct'].extend(targets.cpu().numpy())
            branch_outputs[key]['predict'].extend(predicts.cpu().numpy())

# Display names used for the report rows and the confusion-matrix ticks.
class_names = ["Negative", "Positive"]

for i in (1, 2):
    correct = branch_outputs[str(i)]['correct']
    predict = branch_outputs[str(i)]['predict']

    acc = accuracy_score(correct, predict)
    precision = precision_score(correct, predict, average='weighted')
    recall = recall_score(correct, predict, average='weighted')
    f1 = f1_score(correct, predict, average='weighted')

    # Store the headline metrics in the run summary.
    for suffix, value in (("Acc", acc), ("Pre", precision), ("Rec", recall), ("F1", f1)):
        result[f"Teacher_Branch_{i}_{suffix}"] = value

    print("\n".join([
        f"\n=== Branch {i} ===",
        f"Accuracy: {acc:.4f}",
        f"Precision: {precision:.4f}",
        f"Recall: {recall:.4f}",
        f"F1-Score: {f1:.4f}",
    ]))

    # Save and show the per-class report.
    report = classification_report(correct, predict, target_names=class_names, digits=4)

    path = os.path.join(folder_path, f"classification_report_teacher_branch_{i}")
    with open(path, 'w') as report_file:
        report_file.write(report)

    print(report)

    # Plot, save and show the confusion matrix.
    cm = confusion_matrix(correct, predict)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Confusion Matrix - Branch {i}')
    plt.tight_layout()

    plt.savefig(os.path.join(folder_path, f'confusion_matrix_teacher_branch_{i}.png'))

    plt.show()
    plt.close()

In [None]:
# student: evaluate both branches on the test split
student.eval()

# Ground truth and predictions collected per branch, keyed by '1' / '2'.
branch_outputs = {str(branch): {'correct': [], 'predict': []} for branch in (1, 2)}

with torch.no_grad():
    for images, targets in test_dataloader:
        images, targets = images.to(device), targets.to(device)

        fea1, logit1, fea2, logit2 = student(images)

        for key, outputs in (('1', logit1), ('2', logit2)):
            predicts = torch.max(outputs, dim=1)[1]
            branch_outputs[key]['correct'].extend(targets.cpu().numpy())
            branch_outputs[key]['predict'].extend(predicts.cpu().numpy())

# Display names used for the report rows and the confusion-matrix ticks.
class_names = ["Negative", "Positive"]

for i in (1, 2):
    correct = branch_outputs[str(i)]['correct']
    predict = branch_outputs[str(i)]['predict']

    acc = accuracy_score(correct, predict)
    precision = precision_score(correct, predict, average='weighted')
    recall = recall_score(correct, predict, average='weighted')
    f1 = f1_score(correct, predict, average='weighted')

    # Store the headline metrics in the run summary.
    for suffix, value in (("Acc", acc), ("Pre", precision), ("Rec", recall), ("F1", f1)):
        result[f"Student_Branch_{i}_{suffix}"] = value

    print("\n".join([
        f"\n=== Branch {i} ===",
        f"Accuracy: {acc:.4f}",
        f"Precision: {precision:.4f}",
        f"Recall: {recall:.4f}",
        f"F1-Score: {f1:.4f}",
    ]))

    # Save and show the per-class report.
    report = classification_report(correct, predict, target_names=class_names, digits=4)

    path = os.path.join(folder_path, f"classification_report_student_branch_{i}")
    with open(path, 'w') as report_file:
        report_file.write(report)

    print(report)

    # Plot, save and show the confusion matrix.
    cm = confusion_matrix(correct, predict)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title(f'Confusion Matrix - Branch {i}')
    plt.tight_layout()

    plt.savefig(os.path.join(folder_path, f'confusion_matrix_student_branch_{i}.png'))

    plt.show()
    plt.close()

In [None]:
# Persist the run summary (one-row table), the per-epoch histories and the
# final weights of both networks under folder_path.
result = pd.DataFrame([result])
result.to_csv(os.path.join(folder_path, "result.csv"), index=False)

for history_frame, csv_name in (
    (history_student, "history_student.csv"),
    (history_teacher, "history_teacher.csv"),
):
    history_frame.to_csv(os.path.join(folder_path, csv_name), index=False)

for network, weights_name in (
    (student, "student_weights.pth"),
    (teacher, "teacher_weights.pth"),
):
    torch.save(network.state_dict(), os.path.join(folder_path, weights_name))