# Import Useful Packages

In [None]:
from torch.utils.data import Dataset, DataLoader
import torch
from transformers import BertTokenizer
from random import randrange
import torch.nn.functional as F
from torch import nn, einsum
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torchviz import make_dot


import time
import pandas as pd
import matplotlib.pyplot as plt
import os
import numpy as np

from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix


# Build model

## Spatial Gating Unit

In [None]:
class SpatialGatingUnit(nn.Module):
    """Spatial Gating Unit (SGU) used inside each gMLP block.

    The last (channel) axis of the input is split into two halves ``u`` and
    ``v``. ``v`` is layer-normalized, then mixed along the sequence axis by a
    learned ``(seq_len, seq_len)`` matrix plus a per-position bias; the result
    gates ``u`` elementwise.

    Args:
        d_ffn: channel width of the input; each half has ``d_ffn // 2`` channels.
        seq_len: number of tokens along the sequence axis.
        weight_value: half-width of the uniform init range of the spatial weight.
    """

    def __init__(self, d_ffn,
                 seq_len, weight_value=0.05):
        super().__init__()
        self.norm = nn.LayerNorm(d_ffn // 2)

        # Spatial projection weight, drawn from U(-weight_value, weight_value)
        # so the sequence mixing starts out small.
        self.weight = nn.Parameter(torch.zeros(seq_len, seq_len))
        nn.init.uniform_(self.weight, a=-weight_value, b=weight_value)

        # Per-position bias of the spatial projection, initialized to ones.
        self.bias = nn.Parameter(torch.ones(seq_len))

    def forward(self, x):
        # The einsum below fixes x to rank 3: (batch, seq_len, d_ffn).
        gate, value = x.chunk(2, dim=-1)
        value = self.norm(value)
        # mixed[b, m, d] = sum_n weight[m, n] * value[b, n, d]
        mixed = einsum('b n d, m n -> b m d', value, self.weight)
        # Broadcast the (seq_len,) bias over batch and channel axes.
        mixed = mixed + self.bias[None, :, None]
        return gate * mixed

## Gated-MLP Block

In [None]:
class gMLPBlock(nn.Module):
    """A single gMLP block with a pre-norm residual connection.

    Computes ``x + V(SGU(GELU(U(LayerNorm(x)))))``:
    - ``U`` expands the channels from ``d_model`` to ``d_ffn``.
    - The Spatial Gating Unit halves them to ``d_ffn // 2`` while mixing
      across the sequence axis.
    - ``V`` projects back to ``d_model``.

    As in the gMLP paper ("Pay Attention to MLPs", Liu et al., 2021), ``V`` is
    a plain linear projection; no activation is applied to the residual-branch
    output.

    Args:
        d_model: token embedding width (input and output channels).
        d_ffn: expanded hidden width. It must be even so the SGU can split it
            into two equal halves.
        seq_len: number of tokens; forwarded to the SpatialGatingUnit.
    """

    def __init__(self, d_model, d_ffn, seq_len):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj_U = nn.Sequential(
            nn.Linear(d_model, d_ffn),
            nn.GELU()
        )
        self.sgu = SpatialGatingUnit(d_ffn, seq_len)
        # Linear output projection (no GELU). It stays inside an nn.Sequential
        # so the parameter names remain ``channel_proj_V.0.weight`` /
        # ``channel_proj_V.0.bias``, and previously saved state_dicts still load.
        self.channel_proj_V = nn.Sequential(
            nn.Linear(d_ffn // 2, d_model),
        )

    def forward(self, x):
        res = x
        x = self.norm(x)
        x = self.channel_proj_U(x)
        x = self.sgu(x)
        x = self.channel_proj_V(x)
        return x + res

## Gated_MLP model with L Blocks

In [None]:
class gMLP(nn.Module):
    """A stack of ``num_layers`` gMLPBlock layers applied one after another.

    Every block maps ``(…, seq_len, d_model)`` to the same shape, so the stack
    preserves the input shape.
    """

    def __init__(self, d_model, d_ffn, seq_len, num_layers):
        super().__init__()
        self.layers = nn.ModuleList(
            [gMLPBlock(d_model, d_ffn, seq_len) for _ in range(num_layers)]
        )

    def forward(self, x):
        out = x
        for block in self.layers:
            out = block(out)
        return out

# Train and test functions

## Train epoch

In [None]:
def train_epoch(model, optimizer, criterion, train_dataloader, device, epoch=0, log_interval=50):
    """Run one training epoch.

    Args:
        model: module to train; put in train mode here.
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss callable ``criterion(predictions, labels)``.
        train_dataloader: iterable of ``(inputs, labels)`` batches that supports
            ``len()``.
        device: device each batch is moved to.
        epoch: epoch number, used only in the progress log.
        log_interval: print running accuracy every ``log_interval`` batches.

    Returns:
        ``(epoch_acc, epoch_loss)``:
        - ``epoch_acc`` is the accuracy over *all* samples of the epoch.
        - ``epoch_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    model.train()
    # Whole-epoch counters (returned) are kept separate from the windowed
    # counters that the progress log resets, so logging does not truncate
    # the epoch statistics.
    epoch_correct, epoch_count = 0, 0
    window_correct, window_count = 0, 0
    losses = []

    for idx, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        predictions = model(inputs)
        loss = criterion(predictions, labels)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

        batch_correct = (predictions.argmax(1) == labels).sum().item()
        batch_count = labels.size(0)
        epoch_correct += batch_correct
        epoch_count += batch_count
        window_correct += batch_correct
        window_count += batch_count

        if idx % log_interval == 0 and idx > 0:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {:8.3f}".format(
                    epoch, idx, len(train_dataloader), window_correct / window_count
                )
            )
            window_correct, window_count = 0, 0

    if epoch_count == 0:
        raise ValueError("train_dataloader yielded no batches")

    epoch_acc = epoch_correct / epoch_count
    epoch_loss = sum(losses) / len(losses)
    return epoch_acc, epoch_loss

## Evaluate epoch

In [None]:
def evaluate_epoch(model, criterion, test_dataloader, device):
    """Evaluate ``model`` on ``test_dataloader`` without tracking gradients.

    The model is switched to eval mode first.

    Returns:
        ``(accuracy, mean_loss)``:
        - ``accuracy`` is computed over all samples.
        - ``mean_loss`` is the mean of the per-batch values of ``criterion``.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    batch_losses = []

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            batch_losses.append(criterion(logits, labels).item())
            n_correct += (logits.argmax(1) == labels).sum().item()
            n_seen += labels.size(0)

    accuracy = n_correct / n_seen
    mean_loss = sum(batch_losses) / len(batch_losses)
    return accuracy, mean_loss

## Train

In [None]:
def train(model, model_name, save_model, optimizer, criterion,
          train_dataloader, test_dataloader, num_epochs, device,
          csv_filename):
    """Train ``model`` for ``num_epochs`` epochs and keep the best checkpoint.

    After every epoch the model is evaluated on ``test_dataloader``. Whenever
    the validation loss improves on the best value seen so far in this call,
    the weights are written to ``<save_model>/<model_name>.pt``.

    At the end:
    - the best weights are loaded back into ``model``, if a checkpoint was
      written in this call;
    - the model is put in eval mode;
    - the per-epoch metrics are written to ``csv_filename``.

    Args:
        model, optimizer, criterion: passed through to train_epoch / evaluate_epoch.
        model_name: file stem of the checkpoint.
        save_model: directory the checkpoint is written to (must already exist).
        train_dataloader, test_dataloader: iterables of ``(inputs, labels)`` batches.
        num_epochs: number of epochs to run.
        device: device used for the batches and for reloading the checkpoint.
        csv_filename: path of the metrics CSV.

    Returns:
        ``(model, metrics_df)``:
        - ``model`` holds the best weights.
        - ``metrics_df`` is a DataFrame with one row per epoch.
    """
    train_accs, train_losses = [], []
    eval_accs, eval_losses = [], []
    # Start at +inf so the first finite validation loss always checkpoints.
    best_loss_eval = float("inf")
    checkpoint_saved = False
    times = []
    checkpoint_path = os.path.join(save_model, f"{model_name}.pt")

    for epoch in range(1, num_epochs + 1):
        epoch_start_time = time.time()
        train_acc, train_loss = train_epoch(model, optimizer, criterion, train_dataloader, device, epoch)
        train_accs.append(train_acc)
        train_losses.append(train_loss)

        eval_acc, eval_loss = evaluate_epoch(model, criterion, test_dataloader, device)
        eval_accs.append(eval_acc)
        eval_losses.append(eval_loss)

        if eval_loss < best_loss_eval:
            # Record the new best; without this, every epoch would overwrite
            # the checkpoint.
            best_loss_eval = eval_loss
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True

        epoch_time = time.time() - epoch_start_time
        times.append(epoch_time)
        print("-" * 59)
        print(
            "| End of epoch {:3d} | Time: {:5.2f}s | Train Accuracy {:8.3f} | Train Loss {:8.3f} "
            "| Valid Accuracy {:8.3f} | Valid Loss {:8.3f} ".format(
                epoch, epoch_time, train_acc, train_loss, eval_acc, eval_loss
            )
        )
        print("-" * 59)

    # Only reload a checkpoint produced by this run. Otherwise (e.g. zero
    # epochs, or NaN losses) keep the current weights rather than loading a
    # missing or stale file.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    model.eval()

    # Create a DataFrame from the metrics
    metrics_df = pd.DataFrame({
        'Epoch': list(range(1, num_epochs + 1)),
        'Train Accuracy': train_accs,
        'Train Loss': train_losses,
        'Valid Accuracy': eval_accs,
        'Valid Loss': eval_losses,
        'Time': times
    })

    # Write the per-epoch metrics to a CSV file (no index column)
    metrics_df.to_csv(csv_filename, index=False)

    return model, metrics_df

## Test

In [None]:
def test(model, device, test_loader, loss_fn):
    """Evaluate ``model`` on ``test_loader`` and return rounded summary metrics.

    Args:
        model: classifier, assumed to produce ``(N, num_classes)`` scores.
        device: device the batches are moved to.
        test_loader: iterable of ``(data, target)`` batches.
        loss_fn: loss callable, assumed to return a batch mean (for example
            ``CrossEntropyLoss()`` with its default ``reduction='mean'``).

    Returns:
        dict with these keys, each rounded to two decimals:
        - "Accuracy";
        - "F1 Score", "Recall", "Precision", macro-averaged over classes with
          ``zero_division=0``;
        - "Loss", the mean per-sample loss.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            # Flatten to (N,). Unlike squeeze(), reshape(-1) keeps a batch of
            # size 1 as shape (1,) instead of collapsing it to a scalar.
            target = target.to(device).reshape(-1)
            output = model(data)
            batch_size = target.size(0)

            # Undo the batch mean so the final value is a true per-sample
            # average, even when the last batch is smaller.
            loss_sum += loss_fn(output, target).item() * batch_size

            pred = output.argmax(dim=1)  # index of the highest score per sample
            correct += (pred == target).sum().item()
            total += batch_size

            all_preds.extend(pred.cpu().tolist())
            all_targets.extend(target.cpu().tolist())

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    # Divide by the number of samples actually seen rather than
    # len(test_loader.dataset), so the result stays correct with custom
    # samplers or drop_last.
    accuracy = correct / total
    test_loss = loss_sum / total

    # Calculate additional metrics without using classification_report
    precision = precision_score(all_targets, all_preds, average='macro', zero_division=0)
    recall = recall_score(all_targets, all_preds, average='macro', zero_division=0)
    f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)

    # Round the metrics to two decimal places and return them as a dictionary
    metrics_dict = {
        "Accuracy": round(accuracy, 2),
        "F1 Score": round(f1, 2),
        "Recall": round(recall, 2),
        "Precision": round(precision, 2),
        "Loss": round(test_loss, 2),
    }

    return metrics_dict

## Plot the result

In [None]:
def plot_result(num_epochs, train_accs, eval_accs, train_losses, eval_losses):
    """Plot training/evaluation accuracy and loss curves side by side.

    Args:
        num_epochs: number of epochs; each sequence below must have this length.
        train_accs, eval_accs: per-epoch accuracies (left panel).
        train_losses, eval_losses: per-epoch losses (right panel).

    Returns:
        ``(fig, axs)``, the matplotlib figure and its two axes.
    """
    # Number epochs from 1, matching the 'Epoch' column written by train().
    epochs = list(range(1, num_epochs + 1))
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
    panels = (
        (axs[0], train_accs, eval_accs, "Accuracy"),
        (axs[1], train_losses, eval_losses, "Loss"),
    )
    for ax, train_values, eval_values, ylabel in panels:
        ax.plot(epochs, train_values, label="Training")
        ax.plot(epochs, eval_values, label="Evaluation")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(ylabel)
        # plt.legend() only labels the current (last-created) axes, so add a
        # legend to each panel explicitly.
        ax.legend()
    return fig, axs

# Task

## Image Classification

### Data processing

In [None]:
import torch
from torchvision import datasets, transforms

# Preprocessing: ToTensor() scales pixel values to [0, 1], then Normalize with
# mean = std = 0.5 per channel maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# CIFAR-10 training and test splits, downloaded into ./CIFAR10_data/ when missing
trainset = datasets.CIFAR10(root='./CIFAR10_data/', train=True, download=True, transform=transform)
testset = datasets.CIFAR10(root='./CIFAR10_data/', train=False, download=True, transform=transform)

In [None]:
from torch.utils.data import DataLoader

# Mini-batch size shared by both loaders
batch_size = 32

# Training batches are reshuffled every epoch; test batches keep a fixed order.
trainloader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)
testloader = DataLoader(dataset=testset, batch_size=batch_size, shuffle=False)

### Image Classification model

Cho ai chưa biết: trong paper gMLP, các tác giả dùng protocol giống ViT/16 (chia ảnh thành các patch 16×16). Nhưng ảnh CIFAR-10 chỉ có kích thước 32×32, chia patch 16×16 thì chỉ còn 4 patch, hơi kỳ =))). Nên thay vào đó mình dùng patch 8×8 (tạm gọi là "ViT/8", tên tự đặt), tức là chia ảnh 32×32 thành 4×4 = 16 patch, mỗi patch kích thước 8×8.

In [None]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping patches and linearly embed each one.

    A Conv2d whose kernel size and stride both equal ``patch_size`` applies the
    same linear map to every patch. Its output grid is then flattened into a
    token sequence.
    """

    def __init__(self, patch_size, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        # For a batched (B, C, H, W) input: -> (B, embed_dim, H/p, W/p)
        feature_map = self.proj(x)
        # Merge the patch grid into one sequence axis, then move channels last:
        # (B, embed_dim, num_patches) -> (B, num_patches, embed_dim)
        tokens = torch.flatten(feature_map, start_dim=2)
        return tokens.transpose(1, 2)

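Class mới này về cơ bản tương tự `PatchEmbedding` ở trên, nhưng `PatchEmbedding` có thêm bước projection (Conv2d) nên đầu ra là embedding chứ không còn là pixel của patch. Mình không muốn một class phải trả về 2 kết quả (khó kiểm soát khi viết code phía sau), nên tạo class mới chỉ cắt ảnh thành các patch và bỏ bước projection, dùng để trực quan hóa các patch ở bên dưới.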

In [None]:
class PatchExtractor(nn.Module):
    """Cut images into non-overlapping square patches, with no projection.

    Input ``(B, C, H, W)`` becomes ``(B, num_patches, C, patch_size, patch_size)``,
    with patches ordered row-major over the patch grid. Rows or columns that do
    not fill a whole patch are dropped (this is how ``Tensor.unfold`` behaves).
    """

    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def forward(self, x):
        batch, channels, _, _ = x.shape
        p = self.patch_size
        # (B, C, H, W) -> (B, C, nH, nW, p, p)
        grid = x.unfold(2, p, p).unfold(3, p, p)
        # (B, nH, nW, C, p, p) -> (B, nH * nW, C, p, p)
        grid = grid.permute(0, 2, 3, 1, 4, 5)
        return grid.reshape(batch, -1, channels, p, p)

# Initialize the PatchExtractor module
patch_extract = PatchExtractor(patch_size=8)

# Take one training batch and keep only its first image, with the batch
# dimension retained: (1, C, H, W)
image, label = next(iter(trainloader))
image = image[0].unsqueeze(0)

# Cut the image into patches: (1, num_patches, C, p, p)
patches = patch_extract(image)

# Lay the patches out on the same grid they were cut from (row-major), so the
# figure reads like the original image split into tiles.
patches = patches.numpy()
n_patches = patches.shape[1]
grid_rows = image.shape[2] // patch_extract.patch_size
grid_cols = image.shape[3] // patch_extract.patch_size
fig, axs = plt.subplots(grid_rows, grid_cols,
                        figsize=(grid_cols * 2, grid_rows * 2), squeeze=False)
for i, ax in enumerate(axs.flat):
    # (C, p, p) -> (p, p, C), the channel-last layout imshow expects
    patch = np.transpose(patches[0, i], (1, 2, 0))
    # The loader applies Normalize(mean=0.5, std=0.5), so pixel values are in
    # [-1, 1]. Undo that mapping before display instead of clipping away every
    # negative value; the clip only guards against float round-off.
    patch = np.clip(patch * 0.5 + 0.5, 0, 1)
    ax.imshow(patch)
    ax.axis('off')
plt.show()

In [None]:
class gMLP_Vision(nn.Module):
    """gMLP image backbone that returns the [CLS] token's features.

    Pipeline:
    1. Patch embedding.
    2. Prepend a learned [CLS] token.
    3. Add learned position embeddings and apply dropout.
    4. Run the gMLP stack, then LayerNorm.

    ``num_patches`` must equal the number of patches the input produces
    (``(H / patch_size) * (W / patch_size)``). The position embedding and
    the SGU weights are sized for ``num_patches + 1`` tokens.
    """

    def __init__(self, patch_size = 8, num_patches = 16, embed_dim = 768, num_layers = 6):
        super().__init__()
        seq_len = num_patches + 1  # +1 for the prepended [CLS] token
        self.patch_embed = PatchEmbedding(patch_size=patch_size, in_channels=3, embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_drop = nn.Dropout(p=0.1)

        self.blocks = gMLP(d_model=embed_dim,
                           d_ffn=4 * embed_dim,
                           seq_len=seq_len,
                           num_layers=num_layers)

        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        batch_size, _, _, _ = x.shape
        tokens = self.patch_embed(x)

        # One copy of the [CLS] token per image, placed at sequence index 0
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.pos_drop(tokens)

        features = self.norm(self.blocks(tokens))
        return features[:, 0]

In [None]:
class Image_Classification(nn.Module):
    """Image classifier: gMLP_Vision backbone plus a two-layer MLP head.

    The head is Linear -> GELU -> Linear, mapping the backbone's
    ``embed_dim``-wide [CLS] features to ``num_classes`` logits.
    """

    def __init__(self, patch_size = 8, num_patches = 16, embed_dim = 512,
                 num_layers = 3, fc_dim = 256, num_classes = 10):
        super().__init__()
        self.model = gMLP_Vision(
            patch_size=patch_size,
            num_patches=num_patches,
            embed_dim=embed_dim,
            num_layers=num_layers,
        )
        self.fc_1 = nn.Linear(embed_dim, fc_dim)
        self.act = nn.GELU()
        self.head = nn.Linear(fc_dim, num_classes)

    def forward(self, x):
        features = self.model(x)
        hidden = self.act(self.fc_1(features))
        return self.head(hidden)

In [None]:
model = Image_Classification()


def count_parameters(model):
    """Return the total number of scalar values in all of ``model``'s parameters.

    Both trainable and frozen parameters are counted.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


count_parameters(model)

In [None]:
# Choose the device and move the model there *before* constructing the
# optimizer, as the torch.optim documentation recommends. This guarantees the
# optimizer is built over the parameters in their final location.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(),
                  lr=1e-4)
criterion = CrossEntropyLoss()

In [None]:
# Seed the global RNG for what follows (batch shuffling, dropout). The model
# weights were already initialized when `model` was created in an earlier
# cell, so this seed does not affect them.
torch.manual_seed(42)

# Run configuration
num_epochs = 10
model_name = 'Image_Classification'
save_model = './model'
os.makedirs(save_model, exist_ok=True)
save_path = os.path.join(save_model, f"{model_name}.pt")

model, metrics = train(
    model, model_name, save_model, optimizer, criterion,
    trainloader, testloader, num_epochs, device, 'Image_Classification.csv'
)

# train() reloads the checkpoint it saved at this same path; this writes those
# weights out again.
torch.save(model.state_dict(), save_path)
print(f"Model weights saved to {save_path}")

=)))))))))))))))) Ok, model học được, nhưng kết quả còn hơi yếu. Lúc đầu mình set model khá lớn (26M tham số), nhưng máy mình chạy không nổi (26M là "ngủm" đó ae) nên giảm xuống còn khoảng 5M tham số (≈4.98M với cấu hình mặc định của `Image_Classification`). Dù vậy, 26M vẫn là "nhãi con" so với model của big bro Google, nên chủ yếu là chạy cho vui. Ae nào máy mạnh thì cứ mạnh dạn tăng kích thước model lên, vì paper cũng cho thấy model capacity càng lớn thì độ chính xác càng cao.

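Sau khi chạy xong thì có vẻ top-1 accuracy hơi khó với "thằng em" (model nhỏ này), nên mình đo thêm top-5 accuracy cho thằng em dễ thở. Lưu ý: CIFAR-10 chỉ có 10 lớp, nên top-5 là thước đo rất dễ dãi; đoán ngẫu nhiên cũng đã đạt khoảng 50%.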

In [None]:
def evaluate_epoch_topk(model, test_dataloader, device, k):
    """Return the top-``k`` accuracy of ``model`` over ``test_dataloader``.

    A sample counts as correct when its label is among the ``k``
    highest-scoring classes. A ``k`` larger than the number of classes is
    clamped to that number, so every class is then in the top k. The model is
    switched to eval mode and evaluated without gradient tracking.

    Args:
        model: classifier producing ``(N, num_classes)`` scores.
        test_dataloader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to.
        k: how many top-scoring classes to accept.

    Raises:
        ValueError: if ``k < 1`` or the dataloader yields no samples.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            predictions = model(inputs)

            # Tensor.topk raises when k exceeds the size of the class axis.
            k_eff = min(k, predictions.size(1))
            _, predicted = predictions.topk(k_eff, dim=1, largest=True, sorted=True)
            # Each row of `predicted` holds distinct class indices, so a row
            # matches its label at most once; the sum counts correct samples.
            correct += (predicted == labels.view(-1, 1)).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("test_dataloader yielded no samples")

    topk_acc = correct / total
    return topk_acc

# Top-5 accuracy of the trained model on the test split
top5 = evaluate_epoch_topk(model, testloader, device, k=5)
print(f"Top-5 accuracy: {top5:.2%}")