In [None]:
import os
import pickle

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
from torch.utils.data import Dataset
from sklearn.preprocessing import MinMaxScaler

class _CIFAR10Batch(Dataset):
  """Dataset backed by a single CIFAR-10 batch file.

  Subclasses only pick the batch file through `batch_file`; loading,
  normalisation and indexing are shared here so the three splits cannot
  drift apart.

  Items are returned as `(label, image)` pairs, in that order.
  """

  # File name of the batch inside the dataset directory; set by subclasses.
  batch_file = None

  def __init__(self, path):
    """Load `batch_file` from the directory `path`.

    Args:
      path: directory containing the batch files. A trailing separator is
        optional because the file name is appended with `os.path.join`.
    """
    # NOTE(review): `unpickle` is not defined in this cell; it is expected to
    # be provided elsewhere in the notebook and to return a dict exposing the
    # str keys 'data' and 'labels' - confirm. If it wraps `pickle.load`, only
    # point it at trusted files.
    dictionary = unpickle(os.path.join(path, self.batch_file))
    # One row per image, reshaped to channels-first (3, 32, 32).
    pixels = dictionary['data'].reshape(-1, 3, 32, 32)
    # Linear rescale that maps 0 -> -1 and 255 -> 1.
    self.images = (torch.as_tensor(pixels, dtype=torch.float32) * 2) / 255 - 1
    self.labels = dictionary['labels']

  def __len__(self):
    """Number of images in the batch."""
    return self.images.shape[0]

  def __getitem__(self, index):
    """Return the `(label, image)` pair stored at `index`."""
    return self.labels[index], self.images[index]


class CIFAR10Train(_CIFAR10Batch):
  """Training split, read from 'data_batch_1'."""
  batch_file = 'data_batch_1'


class CIFAR10Val(_CIFAR10Batch):
  """Validation split, read from 'data_batch_2'."""
  batch_file = 'data_batch_2'


class CIFAR10Test(_CIFAR10Batch):
  """Test split, read from 'test_batch'."""
  batch_file = 'test_batch'


In [None]:
from torch.utils.data import DataLoader, random_split
batch_size = 32
# All three splits are read from the same extracted CIFAR-10 directory.
DATA_DIR = '/content/cifar-10-batches-py/'

dataset_train = CIFAR10Train(DATA_DIR)
dataset_val = CIFAR10Val(DATA_DIR)
dataset_test = CIFAR10Test(DATA_DIR)

# Only the training split is shuffled: evaluation metrics do not depend on
# sample order, and a fixed order keeps validation/test runs reproducible.
train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader = DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

In [None]:
class BaselineBB(nn.Module):
  """Small four-stage CNN classifier.

  Each stage is conv(3x3, 64 channels, padding 1) -> ReLU -> BatchNorm ->
  2x2 max-pool. The pooled feature map is flattened and fed to a two-layer
  fully connected head that outputs raw (un-normalised) class scores.
  """

  def __init__(self, nclasses, input_size=32):
    """Build the layers.

    Args:
      nclasses: number of output classes (width of the last linear layer).
      input_size: spatial size of the input images, either an int for square
        images or a `(height, width)` pair. It is only used to size the first
        fully connected layer. Defaults to 32.

    Raises:
      ValueError: if the input is too small to survive the four poolings.
    """
    super().__init__()

    if isinstance(input_size, int):
      height, width = input_size, input_size
    else:
      height, width = input_size
    # The padded 3x3 convolutions keep the spatial size; each of the four
    # 2x2/stride-2 poolings halves it (rounding down), i.e. a total of // 16.
    flat_features = 64 * (height // 16) * (width // 16)
    if flat_features <= 0:
      raise ValueError(
          f'input_size {input_size!r} is too small: each side must be at least 16')

    self.nclasses = nclasses
    self.conv1 = nn.Conv2d(3, 64, 3, padding = 1)
    self.conv2 = nn.Conv2d(64, 64, 3, padding = 1)
    self.conv3 = nn.Conv2d(64, 64, 3, padding = 1)
    self.conv4 = nn.Conv2d(64, 64, 3, padding = 1)
    self.bn1 = torch.nn.BatchNorm2d(64)
    self.bn2 = torch.nn.BatchNorm2d(64)
    self.bn3 = torch.nn.BatchNorm2d(64)
    self.bn4 = torch.nn.BatchNorm2d(64)
    self.maxpool = torch.nn.MaxPool2d(2,2)
    self.relu = nn.ReLU()
    self.fc1 = nn.Linear(flat_features, 128)
    self.fc2 = nn.Linear(128, self.nclasses)
    # Not applied in forward(): the network returns logits. Kept as an
    # attribute, with an explicit class dimension, for callers that want
    # probabilities.
    self.softmax = nn.Softmax(dim=1)
    self.flatten = nn.Flatten()

  def forward(self, x):
    """Map a batch of 3-channel images `(N, 3, H, W)` to logits `(N, nclasses)`."""
    x = self.bn1(self.relu(self.conv1(x)))
    x = self.maxpool(x)
    x = self.bn2(self.relu(self.conv2(x)))
    x = self.maxpool(x)
    x = self.bn3(self.relu(self.conv3(x)))
    x = self.maxpool(x)
    x = self.bn4(self.relu(self.conv4(x)))
    x = self.maxpool(x)
    x = self.relu(self.fc1(self.flatten(x)))
    x = self.fc2(x)
    return x

In [None]:
import sys
import time
import numpy as np
def train(net, n_epochs, train_loader, val_loader):
  """Train `net` with Adam and cross-entropy, validating after every epoch.

  Batches from both loaders are expected as `(labels, inputs)` pairs, in that
  order. Data is moved to the device the model's parameters already live on,
  so the function works on CPU as well as on GPU.

  Args:
    net: model to optimise in place; it must have at least one parameter.
    n_epochs: number of passes over `train_loader`.
    train_loader: loader yielding the training batches.
    val_loader: loader yielding the validation batches.

  Returns:
    A tuple `(val_acc, val_loss, elapsed)` where `val_acc` holds the
    validation accuracy in percent for each epoch, `val_loss` the mean
    validation loss per sample for each epoch, and `elapsed` the wall-clock
    training time in seconds.
  """
  # Follow the model's device instead of assuming CUDA is available.
  device = next(net.parameters()).device
  # Only used as the denominator of the progress display.
  total_train = len(train_loader.dataset)
  val_acc = []
  val_loss = []
  t0 = time.time()

  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
  for epoch in range(n_epochs):
    net.train()
    running_loss, running_acc, items = 0.0, 0.0, 0
    for data in train_loader:  # fetch one batch
      labels = data[0].to(device)
      inputs = data[1].float().to(device)
      optimizer.zero_grad()
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      # Count the samples actually seen: the last batch may be smaller than
      # batch_size, so (i + 1) * batch_size would overcount.
      batch_items = labels.size(0)
      items += batch_items
      # `loss` is the batch mean; weight it back so the running average is
      # a true per-sample mean.
      running_loss += loss.item() * batch_items
      predictions = outputs.detach().argmax(dim=1)
      running_acc += (predictions == labels).sum().item()
      info = f'\rEpoch:{epoch+1}({items}/{total_train}), '
      info += f'Loss:{running_loss/items:02.5f}, '
      info += f'Train Acc:{running_acc/items*100:02.1f}%'
      sys.stdout.write(info)

    # Evaluate on the validation set and store the metrics so they can be
    # plotted afterwards.
    net.eval()
    correct, valid_loss, seen = 0.0, 0.0, 0
    with torch.no_grad():
      for data in val_loader:
        labels = data[0].to(device)
        inputs = data[1].float().to(device)
        Y_pred = net(inputs)
        batch_items = labels.size(0)
        seen += batch_items
        correct += (Y_pred.argmax(dim=1) == labels).sum().item()
        valid_loss += criterion(Y_pred, labels).item() * batch_items

    # Avoid a ZeroDivisionError when the validation loader is empty.
    seen = max(seen, 1)
    val_acc.append(correct / seen * 100)
    val_loss.append(valid_loss / seen)
    info = f', Val Acc:{val_acc[-1]:02.2f}%.\n'
    sys.stdout.write(info)
  t1 = time.time()
  return val_acc, val_loss, t1-t0