In [6]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from PIL import Image
from tqdm import tqdm

In [7]:
# Configuration (from repo config file)
# Dataset root. The path is machine-specific, so it can be overridden with the
# TIERED_IMAGENET_ROOT environment variable; the original location stays the default.
DATA_ROOT = os.environ.get("TIERED_IMAGENET_ROOT", r"F:\datasets\archive\tiered_imagenet")
NUM_CLASSES = 351       # default output size of the ResNet18 classification head
BATCH_SIZE = 64         # batch size used by get_dataloader
IMAGE_SIZE = 84         # side length (pixels) of the crops fed to the network
BASE_LR = 0.01          # initial SGD learning rate
LR_GAMMA = 0.1          # multiplicative LR decay applied at each milestone
MOMENTUM = 0.9          # SGD momentum
WEIGHT_DECAY = 5e-4     # SGD weight decay
EPOCHS = 90             # number of training epochs
LR_STEPS = [30, 60]     # epochs at which MultiStepLR decays the learning rate
NUM_WAYS = 5            # classes per few-shot task
NUM_SHOTS = [1, 6]      # support images per class; one evaluation run per entry
NUM_QUERIES = 15        # query images per class in each task
NUM_TASKS = 10000       # few-shot tasks sampled per shot setting
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [8]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + batch-norm stages.

    The input is added back to the output of the second stage before the
    final ReLU. When the stride or the channel count changes, the skip path
    is a 1x1 conv + batch-norm projection; otherwise it is an empty
    ``nn.Sequential`` (identity).
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        conv3x3 = dict(kernel_size=3, padding=1, bias=False)

        # Main branch: only the first conv may down-sample.
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: project only if the shapes of the two branches differ.
        skip_channels = self.expansion * out_channels
        if stride == 1 and in_channels == skip_channels:
            skip_layers = []
        else:
            skip_layers = [
                nn.Conv2d(in_channels, skip_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(skip_channels),
            ]
        self.shortcut = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        relu = nn.functional.relu
        branch = relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        branch += self.shortcut(x)
        return relu(branch)

class ResNet18(nn.Module):
    """ResNet-18 built from ``BasicBlock`` stages.

    Args:
        num_classes: Output size of the final linear layer. A falsy value
            (``None`` or ``0``) replaces the classifier with ``nn.Identity``
            so that ``forward`` returns the flattened pooled features.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super().__init__()
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual blocks
        self.layer1 = self._make_layer(BasicBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes) if num_classes else nn.Identity()

        # Must run after every sub-module has been registered: self.modules()
        # only visits modules that already exist, so calling this before the
        # residual stages are built would leave them with default init and
        # skip the zero-init of their last batch-norm.
        self._initialize_weights()

    # Initialize weights to improve convergence
    def _initialize_weights(self):
        """Kaiming-init convs, reset batch-norms, zero the last BN of each block."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Zero-initialize last BN in each residual branch
        for m in self.modules():
            if isinstance(m, BasicBlock):
                nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one uses ``stride``.

        Updates ``self.in_channels`` so the next stage knows its input width.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, four residual stages, global average pooling and ``fc``."""
        x = nn.functional.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [9]:
class DenseLayer(nn.Module):
    """BN -> ReLU -> 3x3 conv layer whose output is appended to its input.

    ``forward`` returns the input concatenated with ``growth_rate`` new
    channels along dimension 1.
    """

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(
            in_channels,
            growth_rate,
            kernel_size=3,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        activated = nn.functional.relu(self.bn(x))
        new_channels = self.conv(activated)
        return torch.cat((x, new_channels), dim=1)
# Data Transforms (official preprocessing)
# Per-channel statistics shared by the training and evaluation pipelines.
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Training: random crop/flip/colour jitter, then tensor conversion + normalisation.
_train_steps = [
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation: deterministic resize to 1.15x the crop size, then a centre crop.
_test_steps = [
    transforms.Resize(int(IMAGE_SIZE * 1.15)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
test_transform = transforms.Compose(_test_steps)

def get_dataloader(split):
    """Build a DataLoader over the ImageFolder at ``DATA_ROOT/<split>``.

    The ``"train"`` split uses ``train_transform`` and is shuffled; every
    other split uses ``test_transform`` and keeps the dataset order.
    """
    is_train = split == "train"
    if is_train:
        split_transform = train_transform
    else:
        split_transform = test_transform

    dataset = ImageFolder(os.path.join(DATA_ROOT, split), transform=split_transform)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=is_train,
        num_workers=4,
        pin_memory=True,
    )

def train_model():
    """Train ResNet18 on the ``train`` split and return the best model.

    Uses SGD with a MultiStepLR schedule and cross-entropy loss. After every
    epoch the model is scored on the ``val`` split; whenever validation
    accuracy improves, the weights are written to ``resnet18_tiered.pth``.

    Returns:
        The model carrying the weights of the best validation epoch (the same
        weights that were checkpointed), so that a fresh training run and a
        later run that loads the checkpoint work with identical parameters.
    """
    checkpoint_path = "resnet18_tiered.pth"

    train_loader = get_dataloader("train")
    val_loader = get_dataloader("val")

    # Change to ResNet18
    model = ResNet18().to(DEVICE)
    optimizer = optim.SGD(model.parameters(), lr=BASE_LR, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=LR_STEPS, gamma=LR_GAMMA)
    criterion = nn.CrossEntropyLoss()

    # Model output check. Run in eval mode without autograd so the random
    # dummy batch neither updates the batch-norm running statistics nor
    # builds a graph; model.train() is called again at the start of epoch 1.
    print("\n--- Model Output Check ---")
    model.eval()
    with torch.no_grad():
        dummy_input = torch.randn(2, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)
        output = model(dummy_input)
    print("Output shape:", output.shape)
    print("Sample outputs:", output[:2])

    # -inf guarantees the first epoch is checkpointed even if its accuracy is 0.
    best_val_acc = float("-inf")
    checkpoint_saved = False
    for epoch in range(EPOCHS):
        # Training
        model.train()
        train_loss = 0
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)
                val_total += targets.size(0)
                val_correct += (predicted == targets).sum().item()

        # max(..., 1) avoids a ZeroDivisionError when a split is empty.
        val_acc = val_correct / max(val_total, 1)
        print(f"Train Loss: {train_loss/max(len(train_loader), 1):.4f} | Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), checkpoint_path)  # Changed filename
            checkpoint_saved = True

        scheduler.step()

    # The loop ends on the last epoch's weights, which are not necessarily the
    # best ones; restore the checkpointed best weights before returning.
    if checkpoint_saved:
        model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

    return model

#evaluate fewshot

from sklearn.metrics import classification_report

def evaluate_fewshot(model, seed=None):
    """Evaluate nearest-prototype few-shot classification on the test split.

    For every entry of ``NUM_SHOTS``, ``NUM_TASKS`` tasks are sampled. Each
    task draws ``NUM_WAYS`` classes and, per class, ``num_shots`` support and
    ``NUM_QUERIES`` query images. Query features are assigned to the closest
    class prototype (mean support feature, Euclidean distance) under three
    feature transforms:

    * ``"UN"``   - features used as they are,
    * ``"L2N"``  - features divided by their L2 norm,
    * ``"CL2N"`` - features centred with the mean base-class feature, then
      divided by their L2 norm.

    A classification report is printed for every (shots, transform) pair.

    Args:
        model: Feature extractor; its output for a batch is used as the
            feature matrix. It is switched to eval mode.
        seed: Optional seed for NumPy's global RNG, which drives the task
            sampling. ``None`` (default) leaves the RNG untouched.

    Returns:
        ``{num_shots: {transform_name: accuracy}}`` with the overall query
        accuracy accumulated over all tasks.
    """
    transform_names = ["UN", "L2N", "CL2N"]

    if seed is not None:
        np.random.seed(seed)

    test_dataset = get_dataloader("test").dataset

    # Base-class features are extracted with the deterministic evaluation
    # transform so the centring mean lives in the same feature distribution
    # as the support/query features (which also use test_transform).
    base_loader = DataLoader(
        ImageFolder(os.path.join(DATA_ROOT, "train"), transform=test_transform),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )

    # Precompute mean feature from base classes. A running float64 sum is
    # kept instead of concatenating every feature vector in memory.
    model.eval()
    feature_sum = None
    feature_count = 0
    with torch.no_grad():
        for inputs, _ in tqdm(base_loader, desc="Base Features"):
            batch_features = model(inputs.to(DEVICE)).cpu().double()
            batch_sum = batch_features.sum(dim=0, keepdim=True)
            feature_sum = batch_sum if feature_sum is None else feature_sum + batch_sum
            feature_count += batch_features.size(0)
    if feature_count == 0:
        raise ValueError("No base-class images found; cannot compute the mean feature.")
    mean_feature = (feature_sum / feature_count).float()

    def apply_transforms(features, transform):
        """Apply the named feature transform; unknown names return the input."""
        if transform == "L2N":
            return features / features.norm(p=2, dim=1, keepdim=True).clamp(min=1e-7)
        elif transform == "CL2N":
            centered = features - mean_feature
            return centered / centered.norm(p=2, dim=1, keepdim=True).clamp(min=1e-7)
        return features

    def load_images(sample_indices):
        """Load, transform and stack the test images at ``sample_indices``."""
        tensors = []
        for i in sample_indices:
            # The context manager closes the file handle once the pixels have
            # been copied by convert(); otherwise handles pile up over tasks.
            with Image.open(test_dataset.samples[i][0]) as img:
                tensors.append(test_transform(img.convert("RGB")))
        return torch.stack(tensors)

    # Group sample indices by class once instead of rescanning the whole
    # dataset for every class of every task.
    indices_by_class = [[] for _ in test_dataset.classes]
    for sample_idx, (_, class_idx) in enumerate(test_dataset.samples):
        indices_by_class[class_idx].append(sample_idx)
    indices_by_class = [np.asarray(idxs, dtype=np.int64) for idxs in indices_by_class]

    # Queries are laid out class by class, so the label vector is the same
    # for every task.
    task_labels = torch.arange(NUM_WAYS).repeat_interleave(NUM_QUERIES).numpy()

    results = {}
    for num_shots in NUM_SHOTS:
        print(f"\n{num_shots}-shot Classification Report:")

        # Per-transform list of prediction arrays, one array per task
        task_preds = {name: [] for name in transform_names}

        for _ in tqdm(range(NUM_TASKS), desc="Tasks"):
            # Sample task
            class_indices = np.random.choice(len(test_dataset.classes),
                                             NUM_WAYS, replace=False)
            support, query = [], []
            for c in class_indices:
                selected = np.random.choice(indices_by_class[c],
                                            num_shots + NUM_QUERIES, False)
                support.extend(selected[:num_shots])
                query.extend(selected[num_shots:])

            # Process images
            with torch.no_grad():
                sup_features = model(load_images(support).to(DEVICE)).cpu()
                qry_features = model(load_images(query).to(DEVICE)).cpu()

            for name in transform_names:
                # Transform features
                t_sup = apply_transforms(sup_features, name)
                t_qry = apply_transforms(qry_features, name)

                # Calculate prototypes (support rows are grouped by class)
                prototypes = torch.stack([
                    t_sup[i * num_shots:(i + 1) * num_shots].mean(0)
                    for i in range(NUM_WAYS)
                ])

                # Predictions: index of the nearest prototype
                dists = torch.cdist(t_qry, prototypes)
                task_preds[name].append(dists.argmin(dim=1).numpy())

        # Generate reports for each transform and record the overall accuracy
        y_true = np.tile(task_labels, NUM_TASKS)
        results[num_shots] = {}
        for name in transform_names:
            y_pred = np.concatenate(task_preds[name])
            results[num_shots][name] = float((y_pred == y_true).mean())

            print(f"\nTransform: {name}")
            print(classification_report(
                y_true,
                y_pred,
                target_names=[f"Class {i}" for i in range(NUM_WAYS)],
                digits=4,
                zero_division=0
            ))

    return results



In [10]:
if __name__ == "__main__":
    checkpoint_path = "resnet18_tiered.pth"  # Changed filename

    # Train or load model
    if not os.path.exists(checkpoint_path):
        print("Training ResNet-18...")
        model = train_model()
    else:
        print("Loading pretrained model...")
        model = ResNet18(num_classes=None).to(DEVICE)
        state_dict = torch.load(checkpoint_path, map_location=DEVICE)
        # Filter final FC layer weights
        state_dict = {k: v for k, v in state_dict.items() if not k.startswith("fc")}
        # strict=False alone would silently accept a checkpoint whose backbone
        # keys do not match, leaving those layers randomly initialised. Inspect
        # the result and fail loudly instead.
        load_result = model.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            raise RuntimeError(
                f"Checkpoint {checkpoint_path!r} does not match the ResNet18 backbone: "
                f"missing keys {list(load_result.missing_keys)}, "
                f"unexpected keys {list(load_result.unexpected_keys)}"
            )

    # Remove final FC layer for feature extraction
    model.fc = nn.Identity()

    # Evaluate
    evaluate_fewshot(model)

Loading pretrained model...


Base Features: 100%|██████████| 7011/7011 [09:08<00:00, 12.77it/s]



1-shot Classification Report:


Tasks: 100%|██████████| 10000/10000 [50:43<00:00,  3.29it/s] 



Transform: UN
              precision    recall  f1-score   support

     Class 0     0.4532    0.4567    0.4550    150000
     Class 1     0.4607    0.4592    0.4599    150000
     Class 2     0.4606    0.4554    0.4580    150000
     Class 3     0.4546    0.4558    0.4552    150000
     Class 4     0.4573    0.4591    0.4582    150000

    accuracy                         0.4573    750000
   macro avg     0.4573    0.4573    0.4573    750000
weighted avg     0.4573    0.4573    0.4573    750000


Transform: L2N
              precision    recall  f1-score   support

     Class 0     0.4775    0.4776    0.4775    150000
     Class 1     0.4822    0.4816    0.4819    150000
     Class 2     0.4821    0.4787    0.4804    150000
     Class 3     0.4752    0.4785    0.4769    150000
     Class 4     0.4797    0.4803    0.4800    150000

    accuracy                         0.4793    750000
   macro avg     0.4793    0.4793    0.4793    750000
weighted avg     0.4793    0.4793    0.4793   

Tasks: 100%|██████████| 10000/10000 [55:32<00:00,  3.00it/s] 



Transform: UN
              precision    recall  f1-score   support

     Class 0     0.6405    0.6414    0.6409    150000
     Class 1     0.6407    0.6375    0.6391    150000
     Class 2     0.6417    0.6444    0.6431    150000
     Class 3     0.6437    0.6450    0.6443    150000
     Class 4     0.6410    0.6393    0.6401    150000

    accuracy                         0.6415    750000
   macro avg     0.6415    0.6415    0.6415    750000
weighted avg     0.6415    0.6415    0.6415    750000


Transform: L2N
              precision    recall  f1-score   support

     Class 0     0.6445    0.6445    0.6445    150000
     Class 1     0.6431    0.6410    0.6420    150000
     Class 2     0.6449    0.6483    0.6466    150000
     Class 3     0.6473    0.6482    0.6477    150000
     Class 4     0.6453    0.6430    0.6442    150000

    accuracy                         0.6450    750000
   macro avg     0.6450    0.6450    0.6450    750000
weighted avg     0.6450    0.6450    0.6450   