In [1]:
%load_ext tensorboard


In [2]:
%tensorboard --logdir /kaggle/working/tensorboard_logs

<IPython.core.display.Javascript object>

In [3]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
from typing import Dict
import math
from typing import Optional, List

import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from torch.utils.tensorboard import SummaryWriter

# Initialize TensorBoard writer
writer = SummaryWriter('/kaggle/working/tensorboard_logs')

#####################################
# Hyperparameters
#####################################
EPOCHS = 1
batch_size = 64            # Adjust as your GPU allows
BASE_LR = 1e-3
WEIGHT_DECAY = 0.03
DROPOUT = 0.1
R_LORA = 4              # Example LoRA rank
LORA_ALPHA = 32           # Example LoRA alpha
LORA_DROPOUT = 0



class LoRALayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha

        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class xLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        pretrained_weights=None,  # Added to accept pretrained weights
        pretrained_bias=None,     # Added to accept pretrained bias
        **kwargs
    ):
        super().__init__(in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        if pretrained_weights is not None:
            self.weight.data = pretrained_weights
        if pretrained_bias is not None:
            self.bias.data = pretrained_bias

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.weight.requires_grad = False
        self._initialize_lora_parameters()  # Only initialize LoRA parameters
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    # def reset_parameters(self):
    #     nn.Linear.reset_parameters(self)
    #     if hasattr(self, 'lora_A'):
    #         nn.init.zeros_(self.lora_A)
    #         nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))

    def _initialize_lora_parameters(self):
        """
        Initialize only the LoRA-specific parameters (lora_A and lora_B).
        Avoid reinitializing self.weight or self.bias to preserve pretrained values.
        """
        if hasattr(self, 'lora_A'):
            # nn.init.zeros_(self.lora_A)  # A is initialized to zero
            # nn.init.kaiming_uniform_(self.lora_B, a=math.sqrt(5))  # B is initialized as per LoRA paper

            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
            
    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)

def replace_linear_with_lora(module: nn.Module, parent_name='', skip_substring='heads.head'):
    """
    Recursively replace all nn.Linear modules with LoRALayer.Linear,
    while preserving pretrained weights and biases and skipping specific submodules.

    Args:
        module: The module to recursively replace layers in.
        parent_name: Tracking the current path to correctly identify modules to skip.
        skip_substring: Substring to check in the module path to decide skipping replacement.
    """
    for name, child in list(module.named_children()):
        # Form the fully qualified name (like 'encoder.layer1.linear')
        module_path = f"{parent_name}.{name}" if parent_name else name

        # Recursively apply to child modules first
        replace_linear_with_lora(child, parent_name=module_path, skip_substring=skip_substring)

        if isinstance(child, nn.Linear) and skip_substring not in module_path:
            # Extract pretrained weights and bias
            pretrained_weights = child.weight.data.clone()
            pretrained_bias = child.bias.data.clone() if child.bias is not None else None

            # Replace the nn.Linear with LoRA-wrapped Linear
            lora_linear = xLinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=R_LORA,
                lora_alpha=LORA_ALPHA,
                lora_dropout=LORA_DROPOUT,
                pretrained_weights=pretrained_weights,
                pretrained_bias=pretrained_bias,
                
            )
            setattr(module, name, lora_linear)


def count_trainable_parameters(model):
    """
    Counts and returns the number of trainable parameters in the model.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# def mark_only_lora_as_trainable(model: nn.Module, bias: str = 'none') -> None:
#     for n, p in model.named_parameters():
#         if 'lora_' not in n:
#             p.requires_grad = False 
#         else:
#             p.requires_grad = True

#     if bias == 'none':
#         return
#     elif bias == 'all':
#         for n, p in model.named_parameters():
#             if 'bias' in n:
#                 p.requires_grad = True
#             # else:
#             #     p.requires_grad = False

#     elif bias == 'lora_only':
#         for m in model.modules():
#             if isinstance(m, LoRALayer) and \
#                 hasattr(m, 'bias') and \
#                 m.bias is not None:
#                     m.bias.requires_grad = True
#     else:
#         raise NotImplementedError

def mark_lora_and_head_as_trainable(model: nn.Module, head_substring="heads.head", bias='none'):
    """
    Unfreeze LoRA parameters + the final classification head (by default `heads.head`).
    Everything else remains frozen.
    """
    for name, param in model.named_parameters():
        # Unfreeze LoRA parameters
        if 'lora_' in name:
            param.requires_grad = True
        # Unfreeze classification head
        elif head_substring in name:
            print("head_substring came:", name)
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Optionally allow some bias fine-tuning
    if bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, 'bias') and m.bias is not None:
                m.bias.requires_grad = True

# Implement a linear learning rate decay
def lr_lambda(current_step: int):
    """
    Linear decay from step=0 to step=total_steps. At step=0 => 1.0; at step=total_steps => 0.0
    """
    progress = float(current_step) / float(EPOCHS * len(train_loader))
    return max(0.0, 1.0 - progress)

def compare_encoder_weights_consistency_with_xlinear(encoder_before, encoder_after):
    """
    Compare the pretrained weights and biases of nn.Linear layers in the encoder of two models.

    This ensures that the nn.Linear part of xLinear in the modified encoder
    matches the original nn.Linear weights and biases from the pretrained encoder.

    Args:
        encoder_before: Encoder module before applying LoRA and unfreezing logic.
        encoder_after: Encoder module after applying LoRA and unfreezing logic.

    Returns:
        None (prints whether weights and biases are consistent or not).
    """
    print("Comparing nn.Linear weights and biases between original encoder and modified encoder...")

    for (name_before, module_before), (name_after, module_after) in zip(
        encoder_before.named_modules(), encoder_after.named_modules()
    ):
        # Compare only nn.Linear layers in encoder_before to xLinear in encoder_after
        if isinstance(module_before, nn.Linear) and isinstance(module_after, xLinear):
            if torch.equal(module_before.weight.data, module_after.weight.data):
                # print(f"[MATCH] {name_before}: Weights are identical.")
                pass
            else:
                print(f"[MISMATCH] {name_before}: Weights differ.")

            if module_before.bias is not None and module_after.bias is not None:
                if torch.equal(module_before.bias.data, module_after.bias.data):
                    # print(f"[MATCH] {name_before}: Biases are identical.")
                    pass

                else:
                    print(f"[MISMATCH] {name_before}: Biases differ.")
            elif module_before.bias is None and module_after.bias is None:
                # print(f"[MATCH] {name_before}: Both layers have no bias.")
                pass
            else:
                print(f"[MISMATCH] {name_before}: One layer has bias while the other does not.")

    print("Comparison complete.")




torch.manual_seed(17)

transform = transforms.Compose([ 
    transforms.Resize((224, 224)), 
    transforms.ToTensor()        
])
train_dir = "/kaggle/input/tiny-imagenet/tiny-imagenet-200/tiny-imagenet-200/train"

dataset = torchvision.datasets.ImageFolder(train_dir, transform=transform)

train_data, val_data, test_data = torch.utils.data.random_split(dataset, [80000, 10000, 10000])


train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)
# from torchvision.utils import make_grid

# for images, _ in train_loader:
#     plt.figure(figsize=(16,8))
#     plt.axis('off')
#     plt.imshow(make_grid(images, nrow=8).permute((1, 2, 0)))
#     break




#####################################
# 2. Model Preparation
#####################################



# Load pre-trained ViT-B/16 weights from torchvision
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Modify the classification head for CIFAR-10 (10 classes)
num_features = model.heads.head.in_features
model.heads.head = nn.Sequential(
    nn.Dropout(DROPOUT),
    nn.Linear(num_features, 200)   
)

print(f"Number of trainable parameters(total): {count_trainable_parameters(model)}")
print(f"Number of trainable parameters(heads.head): {count_trainable_parameters(model.heads.head)}")
print(f"Number of trainable parameters(encoder): {count_trainable_parameters(model.encoder)}")
print(f"Number of trainable parameters(conv_proj): {count_trainable_parameters(model.conv_proj)}")
  
# Perform the replacement on the entire ViT  
replace_linear_with_lora(model)
mark_lora_and_head_as_trainable(model, head_substring="heads.head", bias="none")
print(f"Number of trainable parameters(total): {count_trainable_parameters(model)}")
print(f"Number of trainable parameters(heads.head): {count_trainable_parameters(model.heads.head)}")
print(f"Number of trainable parameters(encoder): {count_trainable_parameters(model.encoder)}")
print(f"Number of trainable parameters(conv_proj): {count_trainable_parameters(model.conv_proj)}")
print(model.heads.head)

trainable_params_list = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"Trainable parameters: {len(trainable_params_list)}")
print(trainable_params_list[:10])  # print first 10 for inspection

#####################################
# 3. Optimizer & Scheduler
#####################################

# Filter only trainable (LoRA) parameters
trainable_params = filter(lambda p: p.requires_grad, model.parameters())



optimizer = torch.optim.AdamW(trainable_params, lr=BASE_LR, weight_decay=WEIGHT_DECAY)


scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


#####################################
# 4. Training & Validation Loop with Multi-GPU Support
#####################################

# Check if multiple GPUs are available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU
model.to(device)

criterion = nn.CrossEntropyLoss()
    
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # DataParallel splits the batch across GPUs
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()  # Update learning rate each step

        running_loss += loss.item() * images.size(0)

        if step % 10 == 0:
            current_lr = scheduler.get_last_lr()[0]
            print(f"[Epoch {epoch+1}/{EPOCHS} - Step {step}] Loss: {loss.item():.4f}, LR: {current_lr:.6f}")
            # Inside your training loop
            writer.add_scalar('Training Loss', loss.item(), epoch * len(train_loader) + step)
            writer.add_scalar('Learning Rate', current_lr, epoch * len(train_loader) + step)

            # Validation
            model.eval()
            correct = 0
            total = 0
            with torch.no_grad():
                for images, labels in val_loader:
                    images = images.to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    _, predicted = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()

    
    epoch_loss = running_loss / len(train_loader.dataset)
    val_acc = 100.0 * correct / total
    # Inside your validation loop
    writer.add_scalar('Validation Accuracy', val_acc, epoch * len(train_loader) + step)

    print(f"Epoch [{epoch+1}/{EPOCHS}], Training Loss: {epoch_loss:.4f}, Validation Acc: {val_acc:.2f}%\n")

writer.close()


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth
100%|██████████| 330M/330M [00:01<00:00, 218MB/s] 


Number of trainable parameters(total): 85952456
Number of trainable parameters(heads.head): 153800
Number of trainable parameters(encoder): 85207296
Number of trainable parameters(conv_proj): 590592
head_substring came: heads.head.1.weight
head_substring came: heads.head.1.bias
Number of trainable parameters(total): 596168
Number of trainable parameters(heads.head): 153800
Number of trainable parameters(encoder): 442368
Number of trainable parameters(conv_proj): 0
Sequential(
  (0): Dropout(p=0.1, inplace=False)
  (1): Linear(in_features=768, out_features=200, bias=True)
)
Trainable parameters: 74
['encoder.layers.encoder_layer_0.self_attention.out_proj.lora_A', 'encoder.layers.encoder_layer_0.self_attention.out_proj.lora_B', 'encoder.layers.encoder_layer_0.mlp.0.lora_A', 'encoder.layers.encoder_layer_0.mlp.0.lora_B', 'encoder.layers.encoder_layer_0.mlp.3.lora_A', 'encoder.layers.encoder_layer_0.mlp.3.lora_B', 'encoder.layers.encoder_layer_1.self_attention.out_proj.lora_A', 'encoder.la