In [None]:
# Define transformations for the datasets
transform_gtsrb = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
])

# GTSRB dataset
gtsrb_train = datasets.GTSRB(root='./data', split='train', transform=transform_gtsrb, download=True)
gtsrb_test = datasets.GTSRB(root='./data', split='test', transform=transform_gtsrb, download=True)



In [None]:
# Define transformations for the datasets
transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

# MNIST dataset
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform_mnist, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform_mnist, download=True)


In [None]:
# Calculate mean and standard deviation for GTSRB
imgs = torch.stack([img for img, _ in gtsrb_train])
mean = imgs.mean()
std = imgs.std()

print("Mean:", mean)
print("Standard Deviation:", std)


Mean: tensor(0.3223)
Standard Deviation: tensor(0.2604)


In [None]:
# Calculate mean and standard deviation for MNIST
imgs_mnist = torch.stack([img for img, _ in mnist_train])
mean_mnist = imgs_mnist.mean()
std_mnist = imgs_mnist.std()

print("Mean (MNIST):", mean_mnist)
print("Standard Deviation (MNIST):", std_mnist)


Mean (MNIST): tensor(0.1309)
Standard Deviation (MNIST): tensor(0.2893)


In [96]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Define transformations for the datasets
transform_gtsrb = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.3223,), (0.2604,))
])

transform_mnist = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize((0.1309,), (0.2893,))
])


# MNIST dataset
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# GTSRB dataset
gtsrb_train = datasets.GTSRB(root='./data', split='train', transform=transform, download=True)
gtsrb_test = datasets.GTSRB(root='./data', split='test', transform=transform, download=True)

# Data loaders
train_loader_mnist = DataLoader(dataset=mnist_train, batch_size=256, shuffle=True)
test_loader_mnist = DataLoader(dataset=mnist_test, batch_size=256, shuffle=False)

train_loader_gtsrb = DataLoader(dataset=gtsrb_train, batch_size=256, shuffle=True)
test_loader_gtsrb = DataLoader(dataset=gtsrb_test, batch_size=256, shuffle=False)


In [97]:
import random
import typing as tp

import torch


class MultiDataLoader:
    """
    from pnpl

    Provides batches randomly selected from multiple dataloaders. Ensures that nothing within a batch from
    one dataloader is ever mixed with data from another dataloader.

    on each iteration returns batch, dataset_name
    """

    def __init__(
        self, dataloaders: tp.List[torch.utils.data.Dataset], shuffle: bool = True, dataset_names = None
    ):
        """

        Keyword arguments:
        dataloaders -- list of dataloaders to randomly sample batches from
        shuffle -- whether to shuffle the order of random sampling on every __iter__ call
        """

        self.dataloaders = dataloaders
        self.dataset_names = dataset_names
        self.data_sizes = [len(dataloader) for dataloader in dataloaders]
        self.data_len = sum(self.data_sizes)
        self.batch_order = []
        self.shuffle = shuffle

        # If not randomly shuffling on each loop, define a random order on instantiation
        if not self.shuffle:
            self._reset()
            self.fixed_batch_order = (
                self.batch_order.copy()
            )  # Store a copy of this defined order

    def __iter__(self):
        if self.shuffle:
            self._reset()  # Generate a new random batch sampling order
        else:
            self.batch_order = self.fixed_batch_order.copy()  # Use the defined order
            self.dataloader_iters = [
                iter(dataloader) for dataloader in self.dataloaders
            ]  # reset dataloaders
        return self

    def __next__(self):
        # Return next sample from randomly selected iterable
        if not self.batch_order:
            raise StopIteration
        dl_idx = self.batch_order.pop(0)
        batch = next(self.dataloader_iters[dl_idx])
        name = self.dataset_names[dl_idx] if self.dataset_names is not None else dl_idx
        return batch, name

    def __len__(self):
        return self.data_len

    def _generate_batch_order(self):
        batch_order = []
        for i, data_size in enumerate(self.data_sizes):
            batch_order.extend([i for _ in range(data_size)])
        random.shuffle(batch_order)
        return batch_order

    def _reset(self):
        self.batch_order = self._generate_batch_order()
        self.dataloader_iters = [iter(dataloader) for dataloader in self.dataloaders]


In [98]:
multi_train_loader = MultiDataLoader([train_loader_mnist, train_loader_gtsrb], shuffle=True, dataset_names = ["mnist", "gtsrb"])
multi_test_loader = MultiDataLoader([test_loader_mnist, test_loader_gtsrb], shuffle=False, dataset_names = ["mnist", "gtsrb"])


In [None]:
# @title
print(mnist_train[0][0].shape)
print(gtsrb_train[0][0].shape)

import matplotlib.pyplot as plt
def imshow(img):
    img = img / 2 + 0.5  # unnormalize
    plt.imshow(img.numpy().squeeze(), cmap='gray')
    plt.show()

imshow(mnist_train[0][0])
imshow(gtsrb_train[0][0])

In [None]:
# @title
!pip install lightning

In [99]:
import lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

class SimpleCNN(pl.LightningModule):
    def __init__(self, dataset="mnist"):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.mnist_head = nn.Linear(128, 10)
        self.gtsrb_head = nn.Linear(128, 43)
        self.speed_sign_head = nn.Linear(128, 9)
        self.dataset = dataset

    def forward(self, x, dataset_name = None):
        if dataset_name is None:
            dataset_name = self.dataset
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        if dataset_name == 'mnist':
            x = self.mnist_head(x)
        elif dataset_name == 'gtsrb':
            x = self.gtsrb_head(x)
        elif dataset_name == 'speed_sign':
            x = self.speed_sign_head(x)

        return x

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

    def training_step(self, batch, batch_idx):
        if(type(batch[1]) == str):
            dataset_name = batch[1]
            batch = batch[0]
        else:
            dataset_name = None
        inputs, labels = batch
        outputs = self(inputs, dataset_name)
        loss = F.cross_entropy(outputs, labels)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        if(type(batch[1]) == str):
            dataset_name = batch[1]
            batch = batch[0]
        else:
            dataset_name = None
        inputs, labels = batch
        outputs = self(inputs, dataset_name)
        loss = F.cross_entropy(outputs, labels)
        self.log('val_loss', loss)
        return loss



In [50]:
trainer = pl.Trainer(max_epochs=5)
mnist_model = SimpleCNN("mnist")
trainer.fit(mnist_model, train_loader_mnist, test_loader_mnist)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | conv1      | Conv2d    | 320    | train
1 | conv2      | Conv2d    | 18.5 K | train
2 | pool       | MaxPool2d | 0      | train
3 | fc1        | Linear    | 401 K  | train
4 | mnist_head | Linear    | 1.3 K  | train
5 | gtsrb_head | Linear    | 5.5 K  | train
-------------------------------------------------
427 K     Trainable params
0         Non-trainable params
427 K     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=5` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=5` reached.


In [None]:
trainer = pl.Trainer(max_epochs=5)
gtsrb_model = SimpleCNN(dataset="gtsrb")
trainer.fit(gtsrb_model, train_loader_gtsrb, test_loader_gtsrb)


In [None]:
multidata_model = SimpleCNN()
trainer = pl.Trainer(max_epochs=5)
trainer.fit(multidata_model, multi_train_loader, multi_test_loader)


In [None]:
multidata_model_mnist = SimpleCNN()
multidata_model_mnist.load_state_dict(multidata_model.state_dict())
multidata_model_mnist.dataset = "mnist"

In [100]:
trainer = pl.Trainer(max_epochs=1)
trainer.fit(multidata_model_mnist, train_loader_mnist, test_loader_mnist)

mnist_model_scratch = SimpleCNN("mnist")
trainer = pl.Trainer(max_epochs=1)
trainer.fit(mnist_model_scratch, train_loader_mnist, test_loader_mnist)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | conv1      | Conv2d    | 320    | train
1 | conv2      | Conv2d    | 18.5 K | train
2 | pool       | MaxPool2d | 0      | train
3 | fc1        | Linear    | 401 K  | train
4 | mnist_head | Linear    | 1.3 K  | train
5 | gtsrb_head | Linear    | 5.5 K  | train
-------------------------------------------------
427 K     Trainable params
0         Non-trainable params
427 K     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [58]:
gtsrb5_mnist1 = SimpleCNN("gtsrb")
gtsrb5_mnist1.load_state_dict(gtsrb_model.state_dict())
gtsrb5_mnist1.dataset = "mnist"
trainer = pl.Trainer(max_epochs=1)
trainer.fit(gtsrb5_mnist1, train_loader_mnist, test_loader_mnist)
trainer.validate(gtsrb5_mnist1, test_loader_mnist)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | conv1      | Conv2d    | 320    | train
1 | conv2      | Conv2d    | 18.5 K | train
2 | pool       | MaxPool2d | 0      | train
3 | fc1        | Linear    | 401 K  | train
4 | mnist_head | Linear    | 1.3 K  | train
5 | gtsrb_head | Linear    | 5.5 K  | train
-------------------------------------------------
427 K     Trainable params
0         Non-trainable params
427 K     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.13648468255996704}]

In [59]:
mnist6 = SimpleCNN("mnist")
mnist6.load_state_dict(mnist_model.state_dict())
mnist6.dataset = "mnist"
trainer = pl.Trainer(max_epochs=1)
trainer.fit(mnist6, train_loader_mnist, test_loader_mnist)
trainer.validate(mnist6, test_loader_mnist)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name       | Type      | Params | Mode 
-------------------------------------------------
0 | conv1      | Conv2d    | 320    | train
1 | conv2      | Conv2d    | 18.5 K | train
2 | pool       | MaxPool2d | 0      | train
3 | fc1        | Linear    | 401 K  | train
4 | mnist_head | Linear    | 1.3 K  | train
5 | gtsrb_head | Linear    | 5.5 K  | train
-------------------------------------------------
427 K     Trainable params
0         Non-trainable params
427 K     Total p

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.02908841334283352}]

In [57]:
trainer.validate(mnist_model_scratch, test_loader_mnist)
trainer.validate(multidata_model_mnist, test_loader_mnist)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.031113268807530403}]

In [66]:
# Filter GTSRB dataset for labels 0-8
speed_sign_train = [(img, target) for img, target in gtsrb_train if target < 9]
speed_sign_test = [(img, target) for img, target in gtsrb_test if target < 9]

# Create new datasets from the filtered data
speed_sign_train_dataset = torch.utils.data.TensorDataset(
    torch.stack([img for img, _ in speed_sign_train]),
    torch.tensor([target for _, target in speed_sign_train])
)
speed_sign_test_dataset = torch.utils.data.TensorDataset(
    torch.stack([img for img, _ in speed_sign_test]),
    torch.tensor([target for _, target in speed_sign_test])
)

train_loader_speed_sign = DataLoader(dataset=speed_sign_train_dataset, batch_size=256, shuffle=True)
test_loader_speed_sign = DataLoader(dataset=speed_sign_test_dataset, batch_size=256, shuffle=False)

In [67]:
multi_ss_mnist_train_dataloader = MultiDataLoader([train_loader_mnist, train_loader_speed_sign], shuffle=True, dataset_names = ["mnist", "speed_sign"])
multi_ss_mnist_test_dataloader = MultiDataLoader([test_loader_mnist, test_loader_speed_sign], shuffle=False, dataset_names = ["mnist", "speed_sign"])

In [102]:
# prompt: create a new cnn train on the multidataloader for one epoch and then mnist for one epoch and compare it to mnist for two epochs

multidata_ss_mnist_model = SimpleCNN()
trainer = pl.Trainer(max_epochs=1)
trainer.fit(multidata_ss_mnist_model, multi_ss_mnist_train_dataloader, multi_ss_mnist_test_dataloader)

multidata_ss_mnist_model.dataset = "mnist"
trainer = pl.Trainer(max_epochs=1)
trainer.fit(multidata_ss_mnist_model, train_loader_mnist, test_loader_mnist)

mnist_model_2_epochs = SimpleCNN("mnist")
trainer = pl.Trainer(max_epochs=2)
trainer.fit(mnist_model_2_epochs, train_loader_mnist, test_loader_mnist)

# Compare the models
trainer.validate(multidata_ss_mnist_model, test_loader_mnist)
trainer.validate(mnist_model_2_epochs, test_loader_mnist)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K  | train
5 | gtsrb_head      | Linear    | 5.5 K  | train
6 | speed_sign_head | Linear    | 1.2 K  | train
------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=2` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.05314750224351883}]

In [103]:
# Train on speed_sign for one epoch
speed_sign_model = SimpleCNN(dataset="speed_sign")
trainer = pl.Trainer(max_epochs=1)
trainer.fit(speed_sign_model, train_loader_speed_sign, test_loader_speed_sign)

# Fine-tune on MNIST for one epoch
speed_sign_model.dataset = "mnist"  # Switch to MNIST head
trainer = pl.Trainer(max_epochs=1)
trainer.fit(speed_sign_model, train_loader_mnist, test_loader_mnist)

trainer.validate(speed_sign_model, test_loader_mnist)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K  | train
5 | gtsrb_head      | Linear    | 5.5 K  | train
6 | speed_sign_head | Linear    | 1.2 K  | train
------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (35) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.10474296659231186}]

In [104]:
# prompt: train and validate on mnist for 1 epoch

# Assuming you want to train a new model from scratch on MNIST
mnist_model = SimpleCNN("mnist")
trainer = pl.Trainer(max_epochs=1)
trainer.fit(mnist_model, train_loader_mnist, test_loader_mnist)


INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name            | Type      | Params | Mode 
------------------------------------------------------
0 | conv1           | Conv2d    | 320    | train
1 | conv2           | Conv2d    | 18.5 K | train
2 | pool            | MaxPool2d | 0      | train
3 | fc1             | Linear    | 401 K  | train
4 | mnist_head      | Linear    | 1.3 K  | train
5 | gtsrb_head      | Linear    | 5.5 K  | train
6 | speed_sign_head | Linear    | 1.2 K  | train
------------------------------------

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=1` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [105]:
trainer.validate(mnist_model, test_loader_mnist)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_loss': 0.08127802610397339}]