In [None]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
from torchvision.models import vit_b_16, ViT_B_16_Weights
from typing import Dict
import math
from typing import Optional, List

import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import wandb
wandb.login(key="")
from sklearn.metrics import f1_score
import yaml

# Load configuration
def load_config(path='/kaggle/input/cccccc/config.yaml'):
    with open(path, 'r') as f:
        return yaml.safe_load(f)

config = load_config()

# Unpack config
R_LORA = config.get('adapter_rank')
SAMPLING_STEPS = config.get('num_subsamples')
SUBADAPTER_UPDATE_FREQ = config.get('subadapter_update_freq')
REG_LAMBDA = config.get('reg_lambda')
REG_DELTA = config.get('reg_delta')

# Initialize WandB with config
wandb.init(project="subadapter_peft", config=config)

#####################################
# Hyperparameters
#####################################
EPOCHS = 1
batch_size = 64            # Adjust as your GPU allows
BASE_LR = 1e-3
WEIGHT_DECAY = 0.03
DROPOUT = 0.1


class LoRALayer():
    def __init__(
        self,
        r: int,
        lora_alpha: int,
        lora_dropout: float,
        merge_weights: bool,
    ):
        self.r = r
        self.lora_alpha = lora_alpha

        # Optional dropout
        if lora_dropout > 0.:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = lambda x: x
        # Mark the weight as unmerged
        self.merged = False
        self.merge_weights = merge_weights


class xLinear(nn.Linear, LoRALayer):
    # LoRA implemented in a dense layer
    def __init__(
        self,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 32,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,
        merge_weights: bool = True,
        pretrained_weights=None,  # Added to accept pretrained weights
        pretrained_bias=None,     # Added to accept pretrained bias
        **kwargs
    ):
        super().__init__(in_features, out_features, **kwargs)
        LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,
                           merge_weights=merge_weights)

        self.fan_in_fan_out = fan_in_fan_out
        if pretrained_weights is not None:
            self.weight.data = pretrained_weights
        if pretrained_bias is not None:
            self.bias.data = pretrained_bias

        # Actual trainable parameters
        if r > 0:
            self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
            self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
            self.scaling = self.lora_alpha / self.r
            self.weight.requires_grad = False
        self._initialize_lora_parameters()  # Only initialize LoRA parameters
        if fan_in_fan_out:
            self.weight.data = self.weight.data.transpose(0, 1)

    def _initialize_lora_parameters(self):
        """
        Initialize only the LoRA-specific parameters (lora_A and lora_B).
        Avoid reinitializing self.weight or self.bias to preserve pretrained values.
        """
        if hasattr(self, 'lora_A'):
            nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B)
            
    def train(self, mode: bool = True):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        nn.Linear.train(self, mode)
        if mode:
            if self.merge_weights and self.merged:
                # Make sure that the weights are not merged
                if self.r > 0:
                    self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = False
        else:
            if self.merge_weights and not self.merged:
                # Merge the weights and mark it
                if self.r > 0:
                    self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
                self.merged = True

    def forward(self, x: torch.Tensor):
        def T(w):
            return w.transpose(0, 1) if self.fan_in_fan_out else w
        if self.r > 0 and not self.merged:
            result = F.linear(x, T(self.weight), bias=self.bias)
            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
            return result
        else:
            return F.linear(x, T(self.weight), bias=self.bias)


# New Sampling-based Subadapter Linear Layer
class SubAdapterLinear(nn.Module):
    def __init__(self, in_features, out_features, r, s, pretrained_weights=None, pretrained_bias=None):
        super().__init__()
        # Base pretrained weights
        self.weight = nn.Parameter(pretrained_weights.clone())
        self.bias = nn.Parameter(pretrained_bias.clone()) if pretrained_bias is not None else None
        # Main adapter matrices
        self.A = nn.Parameter(torch.zeros(in_features, r))  # in_features x r
        self.B = nn.Parameter(torch.zeros(r, out_features))  # r x out_features
        self.r = r
        self.s = s
        self._initialize_subadapter()

    def _initialize_subadapter(self):
        nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))
        nn.init.zeros_(self.B)

    def forward(self, x: torch.Tensor):
        # Base linear output
        base = F.linear(x, self.weight, bias=self.bias)
        # Sampling probabilities based on A columns squared norm
        probs = (self.A ** 2).sum(dim=0)
        probs = probs / probs.sum()
        # Sample s indices with replacement
        idx = torch.multinomial(probs, self.s, replacement=True)
        # Subadapter pointers (views)
        S = self.A[:, idx]  # in_features x s
        R = self.B[idx, :]  # s x out_features
        sub_out = (x @ S @ R) / self.s
        return base + sub_out


def replace_linear_with_subadapter(module: nn.Module, parent_name: str = '', skip_substring: str = 'heads.head'):
    for name, child in list(module.named_children()):
        module_path = f"{parent_name}.{name}" if parent_name else name
        replace_linear_with_subadapter(child, parent_name=module_path, skip_substring=skip_substring)
        if isinstance(child, nn.Linear) and skip_substring not in module_path:
            w = child.weight.data.clone()
            b = child.bias.data.clone() if child.bias is not None else None
            sub = SubAdapterLinear(
                in_features=child.in_features,
                out_features=child.out_features,
                r=R_LORA,
                s=SAMPLING_STEPS,
                pretrained_weights=w,
                pretrained_bias=b
            )
            setattr(module, name, sub)


def count_trainable_parameters(model):
    """
    Counts and returns the number of trainable parameters in the model.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def mark_lora_and_head_as_trainable(model: nn.Module, head_substring="heads.head", bias='none'):
    """
    Unfreeze LoRA parameters + the final classification head (by default `heads.head`).
    Everything else remains frozen.
    """
    for name, param in model.named_parameters():
        # Unfreeze LoRA parameters or SubAdapterLinear parameters A and B
        if 'lora_' in name or name.endswith('.A') or name.endswith('.B'):
            param.requires_grad = True
        # Unfreeze classification head
        elif head_substring in name:
            print("head_substring came:", name)
            param.requires_grad = True
        # Everything else remains frozen
        else:
            param.requires_grad = False

    # Optionally allow some bias fine-tuning
    if bias == 'all':
        for n, p in model.named_parameters():
            if 'bias' in n:
                p.requires_grad = True
    elif bias == 'lora_only':
        for m in model.modules():
            if isinstance(m, LoRALayer) and hasattr(m, 'bias') and m.bias is not None:
                m.bias.requires_grad = True


# Implement a linear learning rate decay
def lr_lambda(current_step: int):
    """
    Linear decay from step=0 to step=total_steps. At step=0 => 1.0; at step=total_steps => 0.0
    """
    progress = float(current_step) / float(EPOCHS * len(train_loader))
    return max(0.0, 1.0 - progress)


torch.manual_seed(17)

transform = transforms.Compose([ 
    transforms.Resize((224, 224)), 
    transforms.ToTensor()        
])
train_dir = "/kaggle/input/tiny-imagenet/tiny-imagenet-200/tiny-imagenet-200/train"

dataset = torchvision.datasets.ImageFolder(train_dir, transform=transform)

from collections import defaultdict
import random

# Few-shot training: select 5 samples per class for the training subset
label_to_indices = defaultdict(list)
for idx, (_, label) in enumerate(dataset):
    label_to_indices[label].append(idx)
fewshot_indices = []
for label, inds in label_to_indices.items():
    fewshot_indices.extend(random.sample(inds, 5))

# Create train and validation subsets
train_dataset = torch.utils.data.Subset(dataset, fewshot_indices)
all_indices = set(range(len(dataset)))
val_indices = list(all_indices - set(fewshot_indices))
val_dataset = torch.utils.data.Subset(dataset, val_indices)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
# Keep test loader if needed (could point to separate test set)
test_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)


#####################################
# 2. Model Preparation
#####################################


# Load pre-trained ViT-B/16 weights from torchvision
model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)

# Modify the classification head for CIFAR-10 (10 classes)
num_features = model.heads.head.in_features
model.heads.head = nn.Sequential(
    nn.Dropout(DROPOUT),
    nn.Linear(num_features, 200)   
)

print(f"Number of trainable parameters(total): {count_trainable_parameters(model)}")
print(f"Number of trainable parameters(heads.head): {count_trainable_parameters(model.heads.head)}")
print(f"Number of trainable parameters(encoder): {count_trainable_parameters(model.encoder)}")
print(f"Number of trainable parameters(conv_proj): {count_trainable_parameters(model.conv_proj)}")
  
# Perform the replacement using sampling-based subadapters
replace_linear_with_subadapter(model)
mark_lora_and_head_as_trainable(model, head_substring="heads.head", bias="none")

# Ensure subadapter A and B are trainable as well
for module in model.modules():
    if isinstance(module, SubAdapterLinear):
        module.A.requires_grad = True
        module.B.requires_grad = True

print(f"Number of trainable parameters(total): {count_trainable_parameters(model)}")
print(f"Number of trainable parameters(heads.head): {count_trainable_parameters(model.heads.head)}")
print(f"Number of trainable parameters(encoder): {count_trainable_parameters(model.encoder)}")
print(f"Number of trainable parameters(conv_proj): {count_trainable_parameters(model.conv_proj)}")
print(model.heads.head)

trainable_params_list = [name for name, param in model.named_parameters() if param.requires_grad]
print(f"Trainable parameters: {len(trainable_params_list)}")
print(trainable_params_list[:10])  # print first 10 for inspection

#####################################
# 3. Optimizer & Scheduler
#####################################

# Filter only trainable (LoRA) parameters
trainable_params = filter(lambda p: p.requires_grad, model.parameters())

optimizer = torch.optim.AdamW(trainable_params, lr=BASE_LR, weight_decay=WEIGHT_DECAY)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


#####################################
# 4. Training & Validation Loop with Multi-GPU Support
#####################################

# Check if multiple GPUs are available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    model = nn.DataParallel(model)  # Wrap the model for multi-GPU
model.to(device)

criterion = nn.CrossEntropyLoss()
    
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0

    for step, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # combined loss
        loss_cls = criterion(outputs, labels)
        reg_loss = 0.0
        for module in model.modules():
            if isinstance(module, SubAdapterLinear):
                A = module.A; B = module.B
                C = torch.sum(torch.norm(A, dim=0) * torch.norm(B, dim=1))
                m, _ = A.shape; p = B.shape[1]
                bound = C / math.sqrt(module.s) * math.sqrt(2 * math.log(m * p) + 2 * math.log(2 / REG_DELTA))
                # reg_loss = reg_loss + bound
        loss = loss_cls + REG_LAMBDA * reg_loss

        loss.backward()
        optimizer.step()
        scheduler.step()

        running_loss += loss.item() * images.size(0)

        if step % 10 == 0:
            current_lr = scheduler.get_last_lr()[0]
            wandb.log({
                'Training Loss': loss.item(),
                'Learning Rate': current_lr
            }, step=epoch * len(train_loader) + step)

    # After each epoch, perform validation
    model.eval()
    all_preds, all_labels = [], []
    val_loss = 0.0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_v = criterion(outputs, labels)
            val_loss += loss_v.item() * images.size(0)
            preds = torch.argmax(outputs, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss /= len(val_loader.dataset)
    val_acc = 100.0 * np.mean(np.array(all_preds) == np.array(all_labels))
    val_f1 = f1_score(all_labels, all_preds, average='macro')
    wandb.log({
        'Validation Loss': val_loss,
        'Validation Accuracy': val_acc,
        'Validation F1': val_f1    
    }, step=(epoch+1) * len(train_loader))


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Number of trainable parameters(total): 85952456
Number of trainable parameters(heads.head): 153800
Number of trainable parameters(encoder): 85207296
Number of trainable parameters(conv_proj): 590592
head_substring came: heads.head.1.weight
head_substring came: heads.head.1.bias
Number of trainable parameters(total): 7231688
Number of trainable parameters(heads.head): 153800
Number of trainable parameters(encoder): 7077888
Number of trainable parameters(conv_proj): 0
Sequential(
  (0): Dropout(p=0.1, inplace=False)
  (1): Linear(in_features=768, out_features=200, bias=True)
)
Trainable parameters: 74
['encoder.layers.encoder_layer_0.self_attention.out_proj.A', 'encoder.layers.encoder_layer_0.self_attention.out_proj.B', 'encoder.layers.encoder_layer_0.mlp.0.A', 'encoder.layers.encoder_layer_0.mlp.0.B', 'encoder.layers.encoder_layer_0.mlp.3.A', 'encoder.layers.encoder_layer_0.mlp.3.B', 'encoder.layers.encoder_layer_1.self_attention.out_proj.A', 'encoder.layers.encoder_layer_1.self_attenti

In [16]:
wandb.finish()