In [None]:
!pip install lightning

In [None]:
import torch
import torch.nn as nn
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import lightning as L

In [None]:
from datasets import load_dataset

# Load the "uoft-cs/cifar100" dataset through the `datasets` library;
# later cells read its "train" and "test" splits.
ds = load_dataset("uoft-cs/cifar100")

In [None]:
# Show the loaded dataset object as this cell's output.
ds

In [None]:
# Keep a handle on each split; they are wrapped into datasets further down.
train = ds["train"]
test = ds["test"]

In [None]:
# Inspect the first training record; later cells read its "img" and "fine_label" fields.
train[0]

In [None]:
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset that yields ``(image, fine_label)`` pairs.

    Args:
        data: Sized, indexable collection whose items are mappings with
            ``"img"`` and ``"fine_label"`` keys.
        transform: Optional callable applied to the image before it is
            returned. ``None`` returns the image untouched.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped collection."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and its label at ``idx``."""
        # Look the record up exactly once: every lookup repeats whatever work the
        # backing store does per access (presumably image decoding for a
        # `datasets` split - confirm), so indexing twice doubles that cost.
        record = self.data[idx]
        image = record["img"]
        label = record["fine_label"]
        # Compare against None explicitly so a callable that happens to be
        # falsy (e.g. an empty container-like transform) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
# Preprocessing shared by both splits: resize to 224x224, then convert to a tensor.
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)
# Wrap each split with the same preprocessing pipeline.
train_dataset, test_dataset = (
    CustomDataset(split, transform=transform) for split in (train, test)
)


In [None]:
import pytorch_lightning as L
from torch.utils.data import DataLoader

class CustomDataModule(L.LightningDataModule):
    """Data module serving a training set and a test set through DataLoaders.

    Only ``train_dataloader`` and ``test_dataloader`` are provided; no
    validation loader is defined here.

    Args:
        train_dataset: Dataset iterated (shuffled) during training.
        test_dataset: Dataset iterated in order during testing.
        batch_size: Number of samples per batch for both loaders.
    """

    def __init__(self, train_dataset, test_dataset, batch_size=120):
        super().__init__()
        self.train_dataset, self.test_dataset = train_dataset, test_dataset
        self.batch_size = batch_size

    def _build_loader(self, dataset, shuffle):
        """Create a DataLoader over ``dataset`` using the configured batch size."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return self._build_loader(self.train_dataset, shuffle=True)

    def test_dataloader(self):
        """Return the in-order loader over the test dataset."""
        return self._build_loader(self.test_dataset, shuffle=False)



In [None]:
# Serve both datasets in batches of 250 (overriding the class default of 120).
data_module = CustomDataModule(train_dataset, test_dataset, batch_size=250)

In [None]:
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    return layer

def conv1x1(in_channels, out_channels, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    layer = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return layer

In [None]:
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution (default 1).
        downsample: Optional module applied to the input before it is added
            back; ``None`` adds the input unchanged.
    """

    # Output channels are ``out_channels * expansion``; this block does not widen.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-bn-relu, conv-bn, add the shortcut, then a final ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        # Project the input only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)

In [None]:
class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack and a shortcut.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels used by the first two convolutions.
        stride: Stride of the first 1x1 convolution (default 1).
        downsample: Optional module applied to the input before it is added
            back; ``None`` adds the input unchanged.
    """

    # The last 1x1 convolution emits ``out_channels * expansion`` channels, so
    # with 1 the block keeps a constant width.
    # NOTE(review): the canonical ResNet bottleneck widens by 4 here; changing
    # this alters every stage's channel count, so verify the shortcut path
    # handles the channel change before doing so.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        widened = out_channels * self.expansion
        self.conv1 = conv1x1(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = conv1x1(out_channels, widened)
        self.bn3 = nn.BatchNorm2d(widened)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-bn stages, add the shortcut, then a final ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        # Project the input only when a downsample module was supplied.
        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)



In [None]:
class ResNet(nn.Module):
    """ResNet backbone: a 7x7 stem, four residual stages, and a linear classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and accept ``(in_channels, out_channels, stride, downsample)``.
        layers: Sequence of four ints, the number of blocks in each stage.
        num_classes: Size of the final classification layer.
    """

    def __init__(self, block, layers, num_classes):
        super(ResNet, self).__init__()
        # Running channel count; _make_layer updates it as stages are built.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Kaiming-normal init for every convolution; unit scale / zero shift
        # for every normalisation layer.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks.

        Args:
            block: Residual block class (see the class docstring).
            out_channels: Base width of the stage, before ``block.expansion``.
            blocks: Number of blocks in the stage.
            stride: Stride of the first block; the remaining blocks use stride 1.

        Returns:
            An ``nn.Sequential`` holding the stage's blocks.
        """
        stage_width = out_channels * block.expansion

        # The shortcut needs a projection whenever the first block changes the
        # tensor shape: either the spatial size (stride != 1) or the channel
        # count (input width differs from the stage's output width). The
        # projection must use the same stride as the block it bypasses.
        downsample = None
        if stride != 1 or self.in_channels != stage_width:
            downsample = nn.Sequential(
                conv1x1(self.in_channels, stage_width, stride),
                nn.BatchNorm2d(stage_width),
            )

        layers = [block(self.in_channels, out_channels, stride, downsample)]
        # Every following block receives the stage's output width.
        self.in_channels = stage_width
        layers.extend(block(self.in_channels, out_channels) for _ in range(1, blocks))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # Residual stages
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # Global average pool, flatten everything but the batch dim, classify.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


In [None]:
# model = ResNet(Bottleneck, [3, 4, 6, 3], 100)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
# loss_fn = nn.CrossEntropyLoss()

In [None]:
# from tqdm import tqdm
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = model.to(device)

In [None]:
import torch
import torch.nn as nn
import pytorch_lightning as L
import torch.nn.functional as F
from torchmetrics import Accuracy
from tqdm import tqdm

class LitResNet(L.LightningModule):
    """Lightning wrapper that trains a classifier with cross-entropy and tracks accuracy.

    Args:
        model: Network mapping a batch of inputs to per-class logits.
        num_classes: Number of target classes; configures the accuracy metrics.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, num_classes, learning_rate=0.1):
        super().__init__()
        self.model = model
        self.num_classes = num_classes
        self.learning_rate = learning_rate
        self.loss_fn = nn.CrossEntropyLoss()

        # One metric object per stage so their accumulated state never mixes.
        # validation_step reads `val_accuracy`, so it has to exist alongside
        # the train/test metrics or the validation loop raises AttributeError.
        self.train_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=num_classes)
        self.test_accuracy = Accuracy(task='multiclass', num_classes=num_classes)

        # Save hyperparameters (the wrapped model itself is excluded).
        self.save_hyperparameters(ignore=['model'])

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def _shared_step(self, batch, metric):
        """Compute loss and accuracy for one ``(inputs, targets)`` batch.

        Args:
            batch: Pair ``(x, y)`` of inputs and integer class targets.
            metric: The stage's accuracy metric; calling it yields the value
                logged for this batch.

        Returns:
            Tuple ``(loss, accuracy)``.
        """
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)
        return loss, metric(preds, y)

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch training loss/accuracy; return the loss."""
        loss, acc = self._shared_step(batch, self.train_accuracy)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss/accuracy; return the loss."""
        loss, acc = self._shared_step(batch, self.val_accuracy)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Log epoch-level test loss/accuracy; return the loss."""
        loss, acc = self._shared_step(batch, self.test_accuracy)
        self.log('test_loss', loss, on_epoch=True)
        self.log('test_acc', acc, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def on_train_epoch_end(self):
        """Log the accumulated training accuracy, then clear the metric state."""
        self.log('train_acc_epoch', self.train_accuracy.compute(), prog_bar=True)
        self.train_accuracy.reset()

    def on_validation_epoch_end(self):
        """Log the accumulated validation accuracy, then clear the metric state."""
        self.log('val_acc_epoch', self.val_accuracy.compute(), prog_bar=True)
        self.val_accuracy.reset()

    def on_test_epoch_end(self):
        """Log the accumulated test accuracy, then clear the metric state."""
        self.log('test_acc_epoch', self.test_accuracy.compute(), prog_bar=True)
        self.test_accuracy.reset()

# Create and use the model: four stages of [3, 4, 6, 3] Bottleneck blocks, 100 output classes.
model = ResNet(Bottleneck, [3, 4, 6, 3], 100)
# LitResNet optimises with Adam. A step size of 0.1 is 100x Adam's default and
# destabilises training from scratch, so use the optimizer's default of 1e-3.
lit_model = LitResNet(model, num_classes=100, learning_rate=1e-3)

# Example of how to train with PyTorch Lightning
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Define callbacks
# Keep the single best checkpoint ranked by `val_acc`. `save_last=True` also
# writes the most recent checkpoint every epoch, so a run still leaves a
# checkpoint behind when `val_acc` is never logged (e.g. no validation loader).
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    mode='max',
    save_top_k=1,
    save_last=True,
    filename='best-{epoch:02d}-{val_acc:.2f}'
)

# Initialize trainer: 10 epochs on one device, using the GPU when CUDA is available.
trainer = Trainer(
    max_epochs=10,
    callbacks=[checkpoint_callback],
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    devices=1
)

# Train the model with the data module built earlier (see the next cell).


In [None]:
# Run training; batches come from the data module's train_dataloader().
# NOTE(review): the data module defines no val_dataloader, so validation_step
# and the `val_acc` metric are presumably never exercised during fit - confirm.
trainer.fit(lit_model, datamodule=data_module)