In [3]:
import torch
from torch import Tensor, nn, optim
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision import datasets

from einops import rearrange

from torch.utils.data import DataLoader, Dataset, random_split

import matplotlib.pyplot as plt
import numpy as np

from tqdm.auto import tqdm

from torchmetrics import Accuracy, F1Score

In [28]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [29]:
# Number of samples per batch, shared by the training and test loaders.
BATCH_SIZE = 8

# Training-time preprocessing: convert the image to a tensor, normalize it with
# mean 0.5 / std 0.5, then apply a random horizontal flip as data augmentation.
transform = T.Compose([
  T.ToTensor(),
  T.Normalize((0.5,), (0.5,)),
  T.RandomHorizontalFlip(),
])

# Evaluation-time preprocessing: the same deterministic steps WITHOUT the random
# flip, so test metrics are measured on unaugmented images and are repeatable.
test_transform = T.Compose([
  T.ToTensor(),
  T.Normalize((0.5,), (0.5,)),
])

train_dataset = datasets.FashionMNIST(
  root='./data',
  train=True,
  transform=transform,
  download=True,
)

# download=False: relies on the files already placed in ./data by the training
# dataset constructed just above.
test_dataset = datasets.FashionMNIST(
  root='./data',
  train=False,
  transform=test_transform,
  download=False,
)

# Create dataloaders to iterate over the datasets.
# The training loader shuffles and drops the last incomplete batch; the test
# loader keeps the dataset order and every sample.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [30]:
# Pull a single batch from the training loader and display the tensor shapes
# as a quick sanity check of the data pipeline.
first_batch = next(iter(train_loader))
images, labels = first_batch
(images.shape, labels.shape)

(torch.Size([8, 1, 28, 28]), torch.Size([8]))

In [31]:
from m_down import DownBlock
from m_conv import Convolutional
from m_swing import SwinTransformer
from m_poolattention import PoolAttention

In [33]:
class Model(nn.Module):
  """Hybrid convolution / Swin-transformer classifier with 10 output logits.

  The network chains a convolutional stem, two (SwinTransformer, DownBlock)
  stages, an attention-pooling layer and a final linear classification head.
  """

  def __init__(self):
    super().__init__()
    # Submodules are created in the same order in which forward() uses them.
    self.conv = Convolutional(1, 64)
    self.vit1 = SwinTransformer(14, 64, 4, 4)
    self.down1 = DownBlock(64, 128)
    self.vit2 = SwinTransformer(7, 128, 4, 4)
    self.down2 = DownBlock(128, 256)
    self.pool = PoolAttention(256, False)
    self.output = nn.Linear(256, 10)

  def forward(self, x:Tensor) -> Tensor:
    """Compute class logits for a batch of images.

    Args:
      x: input batch; the notebook feeds tensors of shape (batch, 1, 28, 28).

    Returns:
      Output of the final linear layer (10 features per pooled vector).
    """
    features = self.conv(x)

    # Each stage: move channels (axis 1) to the last axis for the transformer,
    # then move them back to axis 1 for the down-sampling block.
    stages = (
      (self.vit1, self.down1),
      (self.vit2, self.down2),
    )
    for transformer, downsample in stages:
      features = transformer(features.moveaxis(1, -1))
      features = downsample(features.moveaxis(-1, 1))

    # Flatten the spatial grid into a token sequence: (b, c, h, w) -> (b, h*w, c).
    tokens = rearrange(features, 'b c h w -> b (h w) c')

    pooled = self.pool(tokens)
    return self.output(F.silu(pooled))

# Build the network, then move its parameters and buffers to the selected device.
model = Model()
model = model.to(DEVICE)

In [34]:
# AdamW over all model parameters with learning rate 1e-4 and weight decay 1e-8.
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-8)

# Cross-entropy over the 10 class logits produced by the model.
loss_fn = nn.CrossEntropyLoss()

# The metric objects keep their running state as tensors, so they must live on the
# same device as the predictions/labels passed to update(); otherwise update()
# fails with a device-mismatch error when training on CUDA.
accuracy_metric = Accuracy(num_classes=10, average='weighted', task='multiclass').to(DEVICE)
f1_metric = F1Score(num_classes=10, average='weighted', task='multiclass').to(DEVICE)

In [35]:
def train(model:nn.Module, epoch:int, num_epochs:int=10):
  """Run one training epoch over `train_loader`, reporting progress with tqdm.

  Uses the module-level `train_loader`, `optimizer`, `loss_fn`,
  `accuracy_metric`, `f1_metric` and `DEVICE`.

  Args:
    model: network to optimize; it is switched to training mode.
    epoch: one-based epoch number, shown in the progress-bar label.
    num_epochs: total number of epochs, shown in the progress-bar label.
  """
  model.train()
  total_loss = 0.0
  total = 0
  accuracy_metric.reset()
  f1_metric.reset()

  # The context manager closes the bar even if a step raises.
  with tqdm(total=len(train_loader), desc=f'Train ({epoch}/{num_epochs})', leave=True, colour='blue') as bar:
    for images, labels in train_loader:
      images = images.to(DEVICE)
      labels = labels.to(DEVICE)

      logits = model(images)
      loss:Tensor = loss_fn(logits, labels)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      # Weight the batch loss by its size so the epoch value is a per-sample mean.
      batch_size = images.shape[0]
      batch_loss = loss.item()
      total += batch_size
      total_loss += batch_loss * batch_size

      preds = torch.argmax(logits, dim=1)
      accuracy_metric.update(preds, labels)
      f1_metric.update(preds, labels)

      bar.set_postfix(loss=batch_loss)
      bar.update(1)

    # max(..., 1) avoids a division by zero if the loader yielded no batches.
    total_loss /= max(total, 1)
    total_accuracy = accuracy_metric.compute().item()
    total_f1 = f1_metric.compute().item()

    # set_postfix refreshes the display itself; no extra update() is issued here,
    # since that would push the counter past `total`.
    bar.set_postfix(loss=total_loss, accuracy=total_accuracy, f1=total_f1)

In [37]:
def test(model:nn.Module, epoch:int, num_epochs:int=10):
  """Evaluate the model over `test_loader`, reporting progress with tqdm.

  Uses the module-level `test_loader`, `loss_fn`, `accuracy_metric`,
  `f1_metric` and `DEVICE`. No gradients are computed and no parameters are
  updated.

  Args:
    model: network to evaluate; it is switched to evaluation mode.
    epoch: one-based epoch number, shown in the progress-bar label.
    num_epochs: total number of epochs, shown in the progress-bar label.
  """
  model.eval()
  total_loss = 0.0
  total = 0
  accuracy_metric.reset()
  f1_metric.reset()

  # The bar length must match the loader actually iterated (the test loader),
  # otherwise the bar never reaches 100%. The context manager closes the bar
  # even if a step raises.
  with tqdm(total=len(test_loader), desc=f'Test ({epoch}/{num_epochs})', leave=True, colour='yellow') as bar:
    with torch.inference_mode():
      for images, labels in test_loader:
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(images)
        loss:Tensor = loss_fn(logits, labels)

        # Weight the batch loss by its size so the final value is a per-sample mean.
        batch_size = images.shape[0]
        batch_loss = loss.item()
        total += batch_size
        total_loss += batch_loss * batch_size

        preds = torch.argmax(logits, dim=1)
        accuracy_metric.update(preds, labels)
        f1_metric.update(preds, labels)

        bar.set_postfix(loss=batch_loss)
        bar.update(1)

    # max(..., 1) avoids a division by zero if the loader yielded no batches.
    total_loss /= max(total, 1)
    total_accuracy = accuracy_metric.compute().item()
    total_f1 = f1_metric.compute().item()

    # set_postfix refreshes the display itself; no extra update() is issued here,
    # since that would push the counter past `total`.
    bar.set_postfix(loss=total_loss, accuracy=total_accuracy, f1=total_f1)

In [None]:
# Alternate one training pass and one evaluation pass per epoch.
# The loop counter is zero-based; train/test receive the one-based epoch number.
for epoch in range(10):
  current_epoch = epoch + 1
  train(model, current_epoch)
  test(model, current_epoch)
