## NNs

In [None]:
import torch
import torch.nn as nn

In [None]:
class FFN(nn.Module):
    """Fully connected classifier with three ReLU hidden layers and a linear head.

    The hidden widths are fixed at 512, 256 and 128 units.

    Args:
        input_size: Number of features per sample once the input is flattened.
        output_size: Number of logits produced by the final layer.
    """

    def __init__(self, input_size: int = 28*28, output_size: int = 10):
        super().__init__()
        first_width, second_width, third_width = 512, 256, 128
        self.fc1 = nn.Linear(input_size, first_width)
        self.fc2 = nn.Linear(first_width, second_width)
        self.fc3 = nn.Linear(second_width, third_width)
        self.fc4 = nn.Linear(third_width, output_size)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten each sample, run the hidden stack and return raw logits."""
        # Collapse everything except the leading (batch) dimension.
        activations = x.reshape(x.size(0), -1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            activations = self.relu(hidden_layer(activations))
        # No activation on the head: callers receive unnormalised logits.
        return self.fc4(activations)

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages and a skip path.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (and of the projection, if any).
    """

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()

        # Main branch, first stage: may change resolution and channel count.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Main branch, second stage: keeps resolution and channel count.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # The skip path needs a 1x1 projection only when the main branch
        # changes the tensor shape; otherwise an empty Sequential acts as identity.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ReLU(main_branch(x) + shortcut(x))."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        summed = main + self.shortcut(x)
        return self.relu(summed)
    

class ResNet18(nn.Module):
    """ResNet-18 style classifier with a 3x3 stride-1 stem and no max-pooling.

    Four stages of two ``BasicBlock`` units each (64, 128, 256, 512 channels)
    are followed by global average pooling and a linear classifier.

    Args:
        n_channels: Channels of the input images.
        n_classes: Number of output logits.
    """

    def __init__(self, n_channels: int = 3, n_classes: int = 10):
        super().__init__()
        self.initial_conv = nn.Conv2d(n_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)

        # (input channels, output channels, stride of the first block) per stage.
        stage_specs = ((64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for stage_number, (c_in, c_out, first_stride) in enumerate(stage_specs, start=1):
            stage = self._make_layer(c_in, c_out, num_blocks=2, stride=first_stride)
            setattr(self, f"layer{stage_number}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, n_classes)

    def _make_layer(self, in_channels: int, out_channels: int, num_blocks: int, stride: int) -> nn.Sequential:
        """Build one stage: a (possibly strided) block followed by stride-1 blocks."""
        blocks = [BasicBlock(in_channels, out_channels, stride)]
        blocks.extend(BasicBlock(out_channels, out_channels) for _ in range(num_blocks - 1))
        return nn.Sequential(*blocks)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a batch of images."""
        features = self.relu(self.bn(self.initial_conv(x)))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        # Drop the 1x1 spatial dimensions left by the adaptive pooling.
        return self.fc(pooled.flatten(1))

## Losses

In [None]:
class KDLoss(nn.Module):
    """Knowledge-distillation loss mixing hard-label CE with a softened KL term.

    The returned value is
    ``alpha * CE(student, labels) + (1 - alpha) * T**2 * KL(teacher_T || student_T)``
    where the ``_T`` distributions are softmaxes of the logits divided by the
    temperature ``T``.

    Args:
        teacher_model: Network whose predictions are distilled. It is switched
            to eval mode and only ever run without gradient tracking.
        alpha: Weight of the cross-entropy term; ``1 - alpha`` weights the KL term.
        temperature: Softening temperature ``T``; must be positive.
        class_weights: Optional per-class weights forwarded to ``nn.CrossEntropyLoss``.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """

    def __init__(self, teacher_model: nn.Module, alpha: float = 0.5, temperature: float = 3.0, class_weights=None):
        super().__init__()
        if temperature <= 0:
            # Logits are divided by the temperature, so zero/negative values are meaningless.
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.teacher = teacher_model
        self.teacher.eval()
        self.alpha = alpha
        self.T = temperature
        self.ce = nn.CrossEntropyLoss(weight=class_weights)
        self.kl = nn.KLDivLoss(reduction='batchmean')

    def train(self, mode: bool = True):
        """Set the mode of the loss module while keeping the teacher in eval mode.

        The teacher is a registered submodule, so the default ``train()`` would
        switch it back to training mode (changing BatchNorm/Dropout behaviour).
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, student_logits: torch.Tensor, labels: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        """Compute the distillation loss for one batch.

        Args:
            student_logits: Raw student outputs for ``inputs``.
            labels: Ground-truth class indices for the cross-entropy term.
            inputs: The batch fed to the student; it is re-run through the teacher.
        """
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        loss_ce = self.ce(student_logits, labels)
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        student_soft = nn.functional.log_softmax(student_logits / self.T, dim=1)
        teacher_soft = nn.functional.softmax(teacher_logits / self.T, dim=1)
        # T**2 compensates for the 1/T**2 scaling of the soft-target gradients.
        loss_kl = self.kl(student_soft, teacher_soft) * (self.T ** 2)

        return self.alpha * loss_ce + (1 - self.alpha) * loss_kl
    

class SoftMSELoss(nn.Module):
    """Mean-squared error between student and teacher class probabilities.

    Both sets of logits are turned into probabilities with a softmax over
    dimension 1 before being compared, so the loss is zero exactly when the
    student reproduces the teacher's distribution.

    Args:
        teacher_model: Network providing the soft targets. It is switched to
            eval mode and only ever run without gradient tracking.
    """

    def __init__(self, teacher_model: nn.Module):
        super().__init__()
        self.teacher = teacher_model
        # Soft targets must come from the teacher's inference behaviour.
        self.teacher.eval()
        self.mse = nn.MSELoss()

    def train(self, mode: bool = True):
        """Set the mode of the loss module while keeping the teacher in eval mode."""
        super().train(mode)
        self.teacher.eval()
        return self

    def forward(self, student_logits: torch.Tensor, labels: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
        """Compute the soft-label MSE for one batch.

        Args:
            student_logits: Raw student outputs for ``inputs``.
            labels: Unused; accepted so the call signature matches ``KDLoss``.
            inputs: The batch fed to the student; it is re-run through the teacher.
        """
        with torch.no_grad():
            teacher_logits = self.teacher(inputs)

        # Compare probabilities with probabilities. Using log-probabilities for
        # the student would make the two sides incommensurable.
        student_soft = nn.functional.softmax(student_logits, dim=1)
        teacher_soft = nn.functional.softmax(teacher_logits, dim=1)
        return self.mse(student_soft, teacher_soft)

## Utils

In [None]:
import numpy as np
from typing import Literal
from torchvision import datasets, transforms
from torch.utils.data import random_split, Dataset, Subset


def get_dataset(name: Literal['MNIST', 'CIFAR10']):
    """Download a torchvision dataset and return (train, validation, test) splits.

    The official training set is split with a fixed seed (42) into a training
    part and a 10000-sample validation part; the official test set is returned
    unchanged. Images are converted to tensors only.

    Args:
        name: Either ``'MNIST'`` or ``'CIFAR10'``.

    Raises:
        ValueError: If ``name`` is not one of the supported datasets.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])

    if name == 'MNIST':
        dataset_cls, split_sizes = datasets.MNIST, [50000, 10000]
    elif name == 'CIFAR10':
        dataset_cls, split_sizes = datasets.CIFAR10, [40000, 10000]
    else:
        raise ValueError(f"Unsupported dataset: {name}")

    full_train = dataset_cls(root='./data', train=True, download=True, transform=to_tensor)
    # A dedicated, seeded generator keeps the train/validation split reproducible.
    split_rng = torch.Generator().manual_seed(42)
    X_train, X_val = random_split(full_train, split_sizes, generator=split_rng)
    X_test = dataset_cls(root='./data', train=False, download=True, transform=to_tensor)

    return X_train, X_val, X_test

def downsample_dataset(dataset: Dataset, class_label: int, fraction: float) -> Dataset:
    """Return a ``Subset`` in which one class is randomly thinned out.

    All samples of the other classes are kept (in their original order),
    followed by a seeded random selection of ``fraction`` of the samples whose
    label equals ``class_label``.

    Args:
        dataset: Map-style dataset yielding ``(sample, label)`` pairs.
        class_label: Label of the class to downsample.
        fraction: Share of that class to keep, in ``[0, 1]``.

    Raises:
        ValueError: If ``fraction`` lies outside ``[0, 1]``.
    """
    if not 0.0 <= fraction <= 1.0:
        raise ValueError(f"fraction must be in [0, 1], got {fraction}")

    # Labels are read by indexing every item, which also runs the dataset's
    # transform on each sample; acceptable for a one-off preprocessing step.
    labels = np.array([int(dataset[idx][1]) for idx in range(len(dataset))], dtype=np.int64)

    downsample_mask = labels == class_label
    keep_indices = np.flatnonzero(~downsample_mask).tolist()
    downsample_indices = np.flatnonzero(downsample_mask)

    n_samples = int(len(downsample_indices) * fraction)
    print(f"Downsampling class {class_label} from {len(downsample_indices)} to {n_samples} samples.")

    if n_samples:
        # A local RandomState(42) draws the same sample as seeding the global
        # NumPy RNG with 42 would, without clobbering the global RNG state.
        rng = np.random.RandomState(42)
        sampled_indices = rng.choice(downsample_indices, size=n_samples, replace=False).tolist()
    else:
        sampled_indices = []

    return Subset(dataset, keep_indices + sampled_indices)


def calculate_class_weights(dataset: Dataset, num_classes: int = 10) -> torch.Tensor:
    """Compute inverse-frequency class weights normalised to sum to ``num_classes``.

    Args:
        dataset: Iterable of ``(sample, label)`` pairs; every item is visited once.
        num_classes: Length of the returned weight vector.

    Returns:
        A float tensor of shape ``(num_classes,)``.
    """
    occurrences = torch.zeros(num_classes)

    for _sample, target in dataset:
        occurrences[target] += 1

    # The small constant avoids a division by zero for classes with no samples.
    inverse_freq = 1.0 / (occurrences + 1e-6)
    normaliser = inverse_freq.sum()

    return inverse_freq / normaliser * num_classes

In [None]:
from sklearn.metrics import balanced_accuracy_score, confusion_matrix, precision_recall_fscore_support
import matplotlib.pyplot as plt

def evaluate_model(model: nn.Module, dataset: Dataset, device: torch.device) -> dict:
    """Run ``model`` over ``dataset`` and collect classification metrics.

    The model is put in eval mode and evaluated without gradient tracking in
    batches of 128. Predictions are the arg-max over dimension 1 of the outputs.

    Returns:
        A dict with balanced accuracy, the confusion matrix, per-class
        precision/recall/F1/support and macro-averaged precision/recall/F1.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False)

    predicted = []
    expected = []

    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            logits = model(batch_inputs.to(device))
            batch_preds = torch.max(logits, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            # Labels never leave the CPU, so they can be converted directly.
            expected.extend(batch_labels.numpy())

    y_pred = np.array(predicted)
    y_true = np.array(expected)

    results = {
        'balanced_accuracy': balanced_accuracy_score(y_true, y_pred),
        'confusion_matrix': confusion_matrix(y_true, y_pred),
    }

    # average=None yields one value per class (plus the per-class support).
    per_class = precision_recall_fscore_support(y_true, y_pred, average=None)
    results['precision_per_class'] = per_class[0]
    results['recall_per_class'] = per_class[1]
    results['f1_per_class'] = per_class[2]
    results['support_per_class'] = per_class[3]

    # average='macro' yields the unweighted mean over classes.
    macro = precision_recall_fscore_support(y_true, y_pred, average='macro')
    results['precision_macro'] = macro[0]
    results['recall_macro'] = macro[1]
    results['f1_macro'] = macro[2]

    return results

def pprint_eval(eval_dict: dict):
    """Print the metrics in ``eval_dict`` and plot its annotated confusion matrix.

    Args:
        eval_dict: Mapping with the keys ``'balanced_accuracy'``,
            ``'precision_per_class'``, ``'recall_per_class'``, ``'f1_per_class'``,
            ``'support_per_class'``, ``'precision_macro'``, ``'recall_macro'``,
            ``'f1_macro'`` and ``'confusion_matrix'`` (a 2-D array).
    """
    print(f"Balanced Accuracy: {eval_dict['balanced_accuracy']:.4f}")
    print("Per-class Metrics:")
    per_class_rows = zip(
        eval_dict['precision_per_class'],
        eval_dict['recall_per_class'],
        eval_dict['f1_per_class'],
        eval_dict['support_per_class'],
    )
    for class_idx, (precision, recall, f1, support) in enumerate(per_class_rows):
        print(f" Class {class_idx}: Precision: {precision:.4f}, "
              f"Recall: {recall:.4f}, "
              f"F1-Score: {f1:.4f}, "
              f"Support: {support}")
    print(f"Macro-Averaged Metrics: Precision: {eval_dict['precision_macro']:.4f}, "
          f"Recall: {eval_dict['recall_macro']:.4f}, "
          f"F1-Score: {eval_dict['f1_macro']:.4f}"
          )

    # Use the argument, not a global: the matrix to plot is the one passed in.
    conf_matrix = eval_dict['confusion_matrix']
    n_rows, n_cols = conf_matrix.shape

    plt.imshow(conf_matrix, cmap='Blues')
    # Write the count into every cell of the heat map.
    for row in range(n_rows):
        for col in range(n_cols):
            plt.text(col, row, conf_matrix[row, col], ha='center', va='center', color='black')
    plt.colorbar()
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    # Tick positions follow the matrix size instead of assuming ten classes.
    plt.xticks(ticks=range(n_cols))
    plt.yticks(ticks=range(n_rows))
    plt.show()


In [None]:
from tqdm import tqdm

def train_model(model: nn.Module, 
                train_data: torch.utils.data.Dataset, 
                val_data: torch.utils.data.Dataset,
                batch_size: int,
                learning_rate: float,
                num_epochs: int,
                device: torch.device,
                criterion: "nn.Module | None" = None,
                checkpoint_path: str = 'best_model.pth',
):
    """Train ``model`` with Adam and validate it after every epoch.

    Args:
        model: Network to optimise; expected to already live on ``device``.
        train_data: Dataset of ``(inputs, labels)`` pairs used for optimisation.
        val_data: Dataset of ``(inputs, labels)`` pairs used for validation.
        batch_size: Batch size of both data loaders.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        num_epochs: Number of passes over ``train_data``.
        device: Device the batches (and the criterion) are moved to.
        criterion: Training loss. ``None`` selects a fresh ``nn.CrossEntropyLoss``.
            Teacher-based losses (``KDLoss``, ``SoftMSELoss``) are additionally
            given the input batch so they can query their teacher.
        checkpoint_path: File that receives the ``state_dict`` of the epoch with
            the best validation accuracy.

    Returns:
        ``(model, metrics)`` where ``model`` holds the weights of the LAST epoch
        (the best ones are only on disk) and ``metrics`` maps ``"train_loss"``,
        ``"val_loss"`` and ``"val_accuracy"`` to per-epoch lists.
    """
    if criterion is None:
        # Created per call instead of once at definition time.
        criterion = nn.CrossEntropyLoss()
    if isinstance(criterion, nn.Module):
        # Moves loss buffers (e.g. class weights) next to the logits.
        criterion = criterion.to(device)
    # Distillation losses take the raw inputs as a third argument.
    needs_inputs = isinstance(criterion, (KDLoss, SoftMSELoss))

    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Validation always uses plain cross-entropy so runs with different
    # training losses remain comparable.
    val_criterion = nn.CrossEntropyLoss()

    best_val_accuracy = 0.0
    metrics = {"val_accuracy": [], "train_loss": [], "val_loss": []}

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for inputs, labels in train_pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            if needs_inputs:
                loss = criterion(outputs, labels, inputs)
            else:
                loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the epoch average is per sample.
            running_loss += loss.item() * inputs.size(0)

            train_pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        epoch_loss = running_loss / len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        correct_val = 0
        val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
        with torch.no_grad():
            for inputs, labels in val_pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = val_criterion(outputs, labels)
                val_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(outputs, 1)
                correct_val += (preds == labels).sum().item()

        val_loss /= len(val_loader.dataset)
        val_accuracy = correct_val / len(val_loader.dataset)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), checkpoint_path)

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")
        metrics["train_loss"].append(epoch_loss)
        metrics["val_loss"].append(val_loss)
        metrics["val_accuracy"].append(val_accuracy)

    return model, metrics

## Run

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Bare expression: shows the selected device as the notebook cell output.
device

In [None]:
# Teacher network: fully connected classifier with 10 outputs, moved to `device`.
# The commented line below is the alternative ResNet18 teacher.
teacher_model = FFN(output_size=10).to(device)
# teacher_model = ResNet18(n_classes=10).to(device)

In [None]:
# Load the train/validation/test splits; the commented line switches to CIFAR10.
X_train, X_val, X_test = get_dataset('MNIST')
# X_train, X_val, X_test = get_dataset('CIFAR10')

In [None]:
# Hyperparameters reused by both the teacher and the student training runs.
batch_size = 128
learning_rate = 0.001
num_epochs = 3

# Train the teacher on the full training split; no criterion is passed, so
# train_model falls back to its default loss.
teacher_model, metrics = train_model(
    model=teacher_model,
    train_data=X_train,
    val_data=X_val,
    batch_size=batch_size,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    device=device
)

In [None]:
# Student network: same FFN architecture as the teacher, freshly initialised.
# The commented line below is the alternative single-channel ResNet18 student.
student_model = FFN(output_size=10).to(device)
# student_model = ResNet18(n_channels=1, n_classes=10).to(device)

In [None]:
# Keep only 0.1% of the class-0 training samples to create a strong class imbalance.
print(f"Size before downsampling: {len(X_train)}")
X_train_downsampled = downsample_dataset(X_train, class_label=0, fraction=0.001)
print(f"Size after downsampling class: {len(X_train_downsampled)}")

In [None]:
# Class weighting is currently disabled; the weighted variant is kept commented out.
# class_weights = calculate_class_weights(X_train_downsampled, num_classes=10)
class_weights = None
# Distillation loss that queries the trained teacher for soft targets.
student_criterion = KDLoss(teacher_model=teacher_model, alpha=0.5, temperature=3.0, class_weights=class_weights)

# Train the student on the downsampled data but validate on the untouched split.
# NOTE: rebinding `metrics` discards the teacher's training history from the earlier cell.
student_model, metrics = train_model(
    model=student_model,
    train_data=X_train_downsampled,
    val_data=X_val,
    batch_size=batch_size,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    device=device,
    criterion=student_criterion
)

In [None]:
# Evaluate the student on the held-out test split and report the metrics.
# NOTE(review): the name `eval` shadows the Python built-in of the same name;
# check every other use of this global before renaming it.
eval = evaluate_model(student_model, X_test, device)
pprint_eval(eval)

In [None]:
"""
Pytania Badawcze

1. Redukcja klas treningowych:
   - Jaka jest maksymalna liczba klas, które można usunąć jednocześnie bez degradacji wyników?
   - Jak factor redukcji wpływa na jakość modelu?

2. Wpływ hiperparametrów Knowledge Distillation:
   - Jak parametr alpha wpływa na jakość modelu?
   - Jaki jest optymalny współczynnik temperatury (T)?

3. Wpływ architektury:
   - Czy różne architektury (FFN vs ResNet) reagują podobnie na knowledge distillation?
   - Czy złożoność modelu wpływa na efektywność transferu wiedzy?

4. Porównanie funkcji straty:
   - Cross Entropy vs KL Divergence + Cross Entropy vs MSE z soft labels
   - Zbilansowane metryki vs standardowe (dla imbalanced data)

5. Generalizacja metody:
   - Czy KD działa dla różnych architektur (FFN → ResNet, ResNet → FFN)?
   - Czy zadziała dla różnych zbiorów danych (MNIST vs CIFAR-10)?
"""