In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from torchvision import transforms, datasets
import wandb
from pytorch_lightning.loggers import WandbLogger
import numpy as np
import matplotlib.pyplot as plt
import timm

## EfficientNet-V2 Model

In [2]:
class EfficientNetV2Lightning(pl.LightningModule):
    def __init__(self,
                 num_classes=10,
                 freeze_until='blocks.6',
                 dropout_rate=0.3,
                 learning_rate=1e-4,
                 optimizer='adam'):
        super().__init__()
        self.save_hyperparameters()

        # Load pretrained model
        self.model = timm.create_model('tf_efficientnetv2_s_in21k', pretrained=True)

        # Partial freezing
        freeze = True
        for name, param in self.model.named_parameters():
            if freeze and self.hparams.freeze_until in name:
                freeze = False
            param.requires_grad = not freeze

        # Replace classification head
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(self.hparams.dropout_rate),
            nn.Linear(in_features, self.hparams.num_classes)
        )

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        trainable_params = filter(lambda p: p.requires_grad, self.parameters())
        if self.hparams.optimizer.lower() == 'adam':
            return torch.optim.Adam(trainable_params, lr=self.hparams.learning_rate)
        elif self.hparams.optimizer.lower() == 'adamw':
            return torch.optim.AdamW(trainable_params, lr=self.hparams.learning_rate)
        elif self.hparams.optimizer.lower() == 'sgd':
            return torch.optim.SGD(trainable_params, lr=self.hparams.learning_rate, momentum=0.9)
        else:
            return torch.optim.Adam(trainable_params, lr=self.hparams.learning_rate)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        acc = (logits.argmax(1) == y).float().mean()
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)


## Data Pipeline

In [3]:
class INaturalistDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32, num_workers=4, data_augmentation=False):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_augmentation = data_augmentation
        
    def setup(self, stage=None):
        # Define transformations
        if self.data_augmentation:
            train_transform = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.RandomHorizontalFlip(),
                transforms.RandomRotation(10),
                transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
                transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        else:
            train_transform = transforms.Compose([
                transforms.Resize((384, 384)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
            
        val_transform = transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        
        # Load dataset
        full_dataset = datasets.ImageFolder(root="/kaggle/input/nature-12k/inaturalist_12K/train", transform=train_transform)
        # Split into train/val/test
        train_size = int(0.8 * len(full_dataset))
        val_size = len(full_dataset) - train_size
        
        self.train_dataset, self.val_dataset = random_split(
            full_dataset, [train_size, val_size]
        )
        
        # Apply different transforms to validation and test sets
        self.val_dataset.dataset = datasets.ImageFolder(
            root="/kaggle/input/nature-12k/inaturalist_12K/train", transform=val_transform
        )
        self.test_dataset = datasets.ImageFolder(
            root="/kaggle/input/nature-12k/inaturalist_12K/val", transform=val_transform
        )
    
    def train_dataloader(self):
        return DataLoader(
            self.train_dataset, 
            batch_size=self.batch_size, 
            shuffle=True, 
            num_workers=self.num_workers
        )
    
    def val_dataloader(self):
        return DataLoader(
            self.val_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers
        )
    
    def test_dataloader(self):
        return DataLoader(
            self.test_dataset, 
            batch_size=self.batch_size, 
            shuffle=False, 
            num_workers=self.num_workers
        )

## Tuning with WandB

In [4]:
def train_with_wandb(config=None):
    """Train EfficientNetV2-S with WandB hyperparameter tuning"""
    with wandb.init(config=config):
        config = wandb.config

        # Model
        model = EfficientNetV2Lightning(
            num_classes=10,
            freeze_until=config.freeze_until,
            dropout_rate=config.dropout_rate,
            learning_rate=config.learning_rate,
            optimizer=config.optimizer
        )

        # Data module (resize to 384x384 for EfficientNet)
        data_module = INaturalistDataModule(
            batch_size=config.batch_size,
            data_augmentation=config.data_augmentation,
        )

        # WandB logger
        wandb_logger = WandbLogger(project="inaturalist-efficientnet")

        # Callbacks
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            patience=10,
            mode='min'
        )

        checkpoint_callback = ModelCheckpoint(
            monitor='val_acc',
            dirpath='./checkpoints/',
            filename='inaturalist-efficientnet-{epoch:02d}-{val_acc:.2f}',
            save_top_k=1,
            mode='max'
        )

        # Trainer
        trainer = pl.Trainer(
            max_epochs=10,
            logger=wandb_logger,
            callbacks=[early_stop_callback, checkpoint_callback],
            log_every_n_steps=10
        )

        trainer.fit(model, data_module)


In [5]:
sweep_config = {
    'name': 'EfficientNetV2',
    'method': 'bayes',
    'metric': {'name': 'val_acc', 'goal': 'maximize'},
    'parameters': {
        'learning_rate': {
            'min': 1e-5,
            'max': 5e-4
        },
        'dropout_rate': {
            'values': [0.1, 0.2, 0.3, 0.4]
        },
        'freeze_until': {
            'values': ['blocks.4', 'blocks.5', 'blocks.6']
        },
        'optimizer': {
            'values': ['adamw', 'adam']
        },
        'batch_size': {
            'values': [32, 64]
        },
        'data_augmentation': {
            'values': [True, False]
        }
    }
}


In [5]:
from kaggle_secrets import UserSecretsClient

user_secrets = UserSecretsClient()

In [6]:
wandb_api = user_secrets.get_secret("wandb_api")

In [7]:
wandb.login(key=wandb_api)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mda24m027[0m ([33mda24m027-indian-institute-of-technology-madras[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
sweep_id = wandb.sweep(sweep_config, project="DA6401_Assignment2")

Create sweep with ID: ygknqu29
Sweep URL: https://wandb.ai/da24m027-indian-institute-of-technology-madras/DA6401_Assignment2/sweeps/ygknqu29


In [None]:
wandb.agent(sweep_id, train_with_wandb, count=20)

[34m[1mwandb[0m: Agent Starting Run: z4qg21fm with config:
[34m[1mwandb[0m: 	batch_size: 32
[34m[1mwandb[0m: 	data_augmentation: True
[34m[1mwandb[0m: 	dropout_rate: 0.4
[34m[1mwandb[0m: 	freeze_until: blocks.6
[34m[1mwandb[0m: 	learning_rate: 0.00014997242936342557
[34m[1mwandb[0m: 	optimizer: adam
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  model = create_fn(


model.safetensors:   0%|          | 0.00/193M [00:00<?, ?B/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

## Evaluating Best Model

In [8]:
def visualize_test_results(model, data_module):
    """Visualize test results in a 10x3 grid (10 classes, 3 per class)"""
    import matplotlib.pyplot as plt
    import numpy as np
    import torch
    from torchvision import transforms
    
    # Set model to evaluation mode
    model.eval()
    
    # Get class names
    class_names = data_module.test_dataset.classes if hasattr(data_module.test_dataset, 'classes') else [f"Class {i}" for i in range(10)]
    
    # Get test dataloader
    test_loader = data_module.test_dataloader()
    
    # Create dictionary to store examples for each class
    class_examples = {i: [] for i in range(10)}
    
    # Get examples for each class
    with torch.no_grad():
        for images, labels in test_loader:
            # Get predictions
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            
            # Store examples
            for i, (image, label, pred) in enumerate(zip(images, labels, preds)):
                label_idx = label.item()
                if len(class_examples[label_idx]) < 3:
                    # Convert tensor to numpy for visualization
                    img = image.cpu().numpy().transpose(1, 2, 0)
                    
                    # Denormalize if necessary
                    if hasattr(data_module, 'normalize_transform'):
                        mean = data_module.normalize_transform.mean
                        std = data_module.normalize_transform.std
                        img = img * np.array(std) + np.array(mean)
                        
                    img = np.clip(img, 0, 1)
                    
                    class_examples[label_idx].append({
                        'image': img,
                        'true': class_names[label_idx],
                        'pred': class_names[pred.item()]
                    })
            
            # Check if we have enough examples
            if all(len(examples) >= 3 for examples in class_examples.values()):
                break
    
    # Create 10x3 grid for visualization
    fig, axes = plt.subplots(10, 3, figsize=(15, 30))
    
    for class_idx in range(10):
        examples = class_examples[class_idx]
        for i in range(min(3, len(examples))):
            ax = axes[class_idx, i]
            example = examples[i]
            
            # Display image
            ax.imshow(example['image'])
            
            # Set title with true and predicted labels
            title = f"True: {example['true']}\nPred: {example['pred']}"
            color = 'green' if example['true'] == example['pred'] else 'red'
            ax.set_title(title, color=color)
            
            # Remove axis ticks
            ax.set_xticks([])
            ax.set_yticks([])
    
    plt.tight_layout()
    
    # Log figure to wandb
    wandb.log({"test_predictions": wandb.Image(fig)})
    
    # Close the figure
    plt.close(fig)

In [9]:
#Function is similar to train_with_wandb() with added evaluation and visualization code
def train_and_evaluate(config=None):
    """Train and Evaluate best model config"""
    with wandb.init(config=config, project="DA6401_Assignment2"):
        # Get hyperparameters from wandb
        config = wandb.config
        
        # Create model with the hyperparameters
        model = EfficientNetV2Lightning(
            num_classes=10,
            freeze_until=config.freeze_until,
            dropout_rate=config.dropout_rate,
            learning_rate=config.learning_rate,
            optimizer=config.optimizer
        )
        
        # Create data module
        data_module = INaturalistDataModule(
            batch_size=config.batch_size,
            data_augmentation=config.data_augmentation
        )
        
        # Create WandB logger
        wandb_logger = WandbLogger(project="inaturalist-cnn")
        
        # Create callbacks
        early_stop_callback = EarlyStopping(
            monitor='val_loss',
            patience=10,
            mode='min'
        )
        
        checkpoint_callback = ModelCheckpoint(
            monitor='val_acc',
            dirpath='./checkpoints/',
            filename='inaturalist-efficientnet-{epoch:02d}-{val_acc:.2f}',
            save_top_k=1,
            mode='max'
        )
        
        # Create trainer
        trainer = pl.Trainer(
            max_epochs=30,
            logger=wandb_logger,
            callbacks=[early_stop_callback, checkpoint_callback],
            log_every_n_steps=10
        )
        
        # Train the model
        trainer.fit(model, data_module)
        
        # Test the model
        test_result = trainer.test(model, data_module)
        
        # Log final metrics
        wandb.log({
            "test_accuracy": test_result[0]["test_acc"],
            "test_loss": test_result[0]["test_loss"]
        })
        
        # Visualize test results in a 10x3 grid
        visualize_test_results(model, data_module)

In [10]:
api = wandb.Api()
sweep = api.sweep("da24m027-indian-institute-of-technology-madras/DA6401_Assignment2/ygknqu29")
best_run = sweep.best_run()
best_run_config = best_run.config

[34m[1mwandb[0m: Sorting runs by -summary_metrics.val_acc


In [11]:
best_run_config

{'optimizer': 'adam',
 'batch_size': 64,
 'num_classes': 10,
 'dropout_rate': 0.1,
 'freeze_until': 'blocks.4',
 'learning_rate': 4.1752527543235864e-05,
 'data_augmentation': True}

In [12]:
train_and_evaluate(config=best_run_config)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


  model = create_fn(


model.safetensors:   0%|          | 0.00/193M [00:00<?, ?B/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

0,1
epoch,▁▁▁▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇███████
test_acc,▁
test_accuracy,▁
test_loss,▁▁
train_acc,▁▁▄▄▅▆▆▇▇▆▇▇▇▇█▇█▇██████████████████████
train_loss,█▇▆▆▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇█████
val_acc,▁▆▇▇▇███████████
val_loss,█▃▂▁▁▁▁▁▁▁▁▂▁▂▂▂

0,1
epoch,16.0
test_acc,0.9135
test_accuracy,0.9135
test_loss,0.33643
train_acc,1.0
train_loss,0.00973
trainer/global_step,2000.0
val_acc,0.913
val_loss,0.34524
