In [2]:
import os

import lightning as L
import torch
from lightning.pytorch.loggers import WandbLogger
from torch import optim, nn, utils, Tensor
from torch.utils.data import DataLoader, random_split
from torchvision import models, transforms
from torchvision.datasets import FashionMNIST

In [3]:
# Weights & Biases logger shared by the Trainer below.
#   name      - display name of this run (experiment) inside the project
#   project   - W&B project the run is filed under
#   log_model - per the Lightning WandbLogger docs, True uploads the checkpoints saved
#               by ModelCheckpoint as W&B artifacts at the end of training (this is
#               checkpoint logging, not architecture logging)
# NOTE(review): without prior authentication W&B prompts for an API key interactively
# (see the training cell's output), which blocks an unattended "Run All".
wandb_logger = WandbLogger(
    name="resnet18-transfer-learning",
    project="fashion-mnist",
    log_model=True,
)

In [5]:
class FashionMNISTDataModule(L.LightningDataModule):
    """FashionMNIST prepared for an ImageNet-pretrained ResNet.

    Each 1-channel 28x28 image is replicated to 3 channels, resized to 224x224 and
    normalized with ImageNet channel statistics. A seeded slice of the official
    training split is held out for validation, so the official test split is only
    served by ``test_dataloader``.

    Args:
        batch_size: Samples per batch for every dataloader.
        data_dir: Dataset root; ``None`` means the current working directory
            (the previous, hard-coded location).
        val_size: Number of training images held out for validation.
        num_workers: DataLoader worker processes (0 = load in the main process).
        seed: Seed for the train/validation split so it is reproducible.
    """

    # Channel statistics of the ImageNet data the torchvision ResNet weights were trained on.
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, batch_size=32, data_dir=None, val_size=5000, num_workers=0, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir if data_dir is not None else os.getcwd()
        self.val_size = val_size
        self.num_workers = num_workers
        self.seed = seed
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=3),  # ResNet expects 3 input channels
            transforms.Resize((224, 224)),                # FashionMNIST is 28x28; ResNet expects larger inputs
            transforms.ToTensor(),
            # Pretrained weights assume ImageNet-normalized inputs.
            transforms.Normalize(self.IMAGENET_MEAN, self.IMAGENET_STD),
        ])

    def prepare_data(self):
        """Download both FashionMNIST splits into ``data_dir`` (no state is assigned here)."""
        FashionMNIST(self.data_dir, train=True, download=True)
        FashionMNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or ``None`` (build everything).
                Lightning calls ``setup(stage=...)``; the previous zero-argument
                signature raised ``TypeError`` as soon as ``trainer.fit`` ran.

        Raises:
            ValueError: if ``val_size`` does not leave both splits non-empty.
        """
        if stage in ("fit", "validate", None):
            full_train = FashionMNIST(self.data_dir, train=True, transform=self.transform)
            n_total = len(full_train)
            if not 0 < self.val_size < n_total:
                raise ValueError(
                    f"val_size must be between 1 and {n_total - 1}, got {self.val_size}"
                )
            # Validate on a held-out part of the training split instead of the test split,
            # so model selection never sees test data; the fixed generator keeps it stable.
            self.train_dataset, self.val_dataset = random_split(
                full_train,
                [n_total - self.val_size, self.val_size],
                generator=torch.Generator().manual_seed(self.seed),
            )
        if stage in ("test", None):
            self.test_dataset = FashionMNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        """Shuffled batches from the training subset."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Batches from the held-out validation subset, in fixed order."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Batches from the official FashionMNIST test split, in fixed order."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

In [7]:
class LitModel(L.LightningModule):
    """ResNet-18 with ImageNet weights, fine-tuned end-to-end for image classification.

    Args:
        num_classes: Output size of the replacement classification head
            (10 for FashionMNIST).
        lr: Learning rate for Adam.
    """

    def __init__(self, num_classes=10, lr=1e-3):
        super().__init__()
        # Stores num_classes/lr in self.hparams (and in checkpoints / the logger's config).
        self.save_hyperparameters()
        # `pretrained=True` is deprecated in torchvision; DEFAULT selects the same ImageNet weights.
        self.model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Swap the 1000-class ImageNet head for one sized to this task.
        self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
        # Built once instead of on every step.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return class logits for a batch of images."""
        return self.model(x)

    def _shared_step(self, batch):
        """Compute (loss, accuracy) for one ``(images, labels)`` batch."""
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        acc = (logits.argmax(dim=1) == y).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        # Accuracy previously came from the undefined `self.accuracy` (AttributeError).
        loss, acc = self._shared_step(batch)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        self.log("val_loss", loss, on_epoch=True, prog_bar=True)
        self.log("val_acc", acc, on_epoch=True, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        loss, acc = self._shared_step(batch)
        # Logged under test_* (was "val_acc", which collided with the validation metric).
        self.log("test_loss", loss, on_epoch=True)
        self.log("test_acc", acc, on_epoch=True, prog_bar=True)
        return acc

    def on_train_epoch_end(self):
        # Print the epoch-level training loss logged by training_step, if available.
        loss = self.trainer.callback_metrics.get("train_loss")
        if loss is not None:
            print(f"Train Loss after epoch {self.current_epoch}: {loss:.4f}")

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.hparams.lr)

In [8]:
# Seed the Python/NumPy/torch RNGs (and dataloader workers) so shuffling and the
# random initialisation of the new classification head are reproducible across runs.
L.seed_everything(42, workers=True)

data_module = FashionMNISTDataModule(batch_size=32)
model = LitModel()



In [None]:
# One training epoch (max_epochs = number of epochs); metrics go to the W&B logger.
trainer = L.Trainer(logger=wandb_logger, max_epochs=1)
trainer.fit(model, datamodule=data_module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: