In [1]:
from typing import Any
%cd '/home/aris/projects/grab_exp'

%load_ext autoreload
%autoreload 1

from IPython.display import display

/home/aris/projects/grab_exp


In [2]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import sys
from pathlib import Path
import pickle
from timeit import default_timer as timer
from tqdm.notebook import tqdm

In [3]:
def show_df(df: pd.DataFrame):
    display(df.head())
    print(df.shape)

In [23]:
import lightning as L
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization

In [37]:
from torchvision import transforms

BATCH_SIZE = 16

transform = transforms.Compose([
    transforms.ToTensor(),
    cifar10_normalization(),
])

dm = CIFAR10DataModule(
    data_dir='data/external',
    batch_size=BATCH_SIZE,
    train_transforms=transform,
    test_transforms=transform,
)

In [38]:
import torchvision
from torch import nn


def create_model():
    model = torchvision.models.resnet18(num_classes=10)
    # model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    # model.maxpool = nn.Identity()
    return model

In [43]:
from lightning.pytorch import LightningModule
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.functional import accuracy


class LitResnet(LightningModule):
    def __init__(self, lr=1e-3, wd=5e-4, **kwargs):
        super().__init__()

        self.save_hyperparameters()
        self.model = create_model()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)

        acc = accuracy(preds, y, task='multiclass')
        self.log('train_loss', loss, prog_bar=True,
                 logger=True)
        self.log('train_acc', acc, prog_bar=True,
                 logger=True)
        return loss

    def evaluate(self, batch, stage=None):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        preds = torch.argmax(logits, dim=1)

        acc = accuracy(preds, y, task='multiclass')
        self.log(f'{stage}_loss', loss, prog_bar=True,
                 logger=True)
        self.log(f'{stage}_acc', acc, prog_bar=True,
                 logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        return self.evaluate(batch, 'val')

    def test_step(self, batch, batch_idx):
        return self.evaluate(batch, 'test')

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=self.hparams.lr,
            momentum=0.9,
            weight_decay=self.hparams.wd,
        )

        return optimizer

In [44]:
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch.callbacks import LearningRateMonitor, TQDMProgressBar

model = LitResnet()

trainer = L.Trainer(
    max_epochs=30,
    logger=CSVLogger('logs', name='resnet18'),
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        TQDMProgressBar(),
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [45]:
trainer.fit(model, datamodule=dm)

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type             | Params
---------------------------------------------
0 | model   | ResNet           | 11.2 M
1 | loss_fn | CrossEntropyLoss | 0     
---------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

TypeError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found <pl_bolts.datamodules.cifar10_datamodule.CIFAR10DataModule object at 0x7f0d9027b9a0>.