In [1]:
import torch
from torch import nn
import torchvision
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
import matplotlib.pyplot as plt
from tqdm import tqdm
dev = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean
        
    def __call__(self, tensor):
        return tensor + torch.randn(tensor.size()) * self.std + self.mean
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
    
    __str__ = __repr__

In [3]:
std_transform = Compose([
  ToTensor(),
  Normalize((0.1307), (0.3081)), # mean and standard deviation
  AddGaussianNoise(0., 0.6),
])

In [4]:
train_ds = MNIST(root='downloads', train=True, download=True, transform = std_transform)
valid_ds = MNIST(root='downloads', train=False, download=True, transform = std_transform)

In [5]:
train_loader = DataLoader(train_ds, batch_size=512, shuffle = True, drop_last = True)
valid_loader = DataLoader(valid_ds, batch_size=512, shuffle = True, drop_last = True)

In [6]:
def print_imgs(model=None):
    if (model):
      plt.imshow(model(torch.randn(512, 128).to(dev).cpu().detach()[0,0], cmap='gray'))
    else:
      data, label = next(iter(train_loader))
      print("Truth:", label[0])
      plt.imshow(data[0,0], cmap='gray')
print_imgs()
data, label = next(iter(train_loader))
print(torch.min(data[0]))

Truth: tensor(5)
tensor(-2.9933)


In [7]:
class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()
    self.layers = nn.Sequential(
        nn.Flatten(),
        nn.Linear(28 * 28, 50),
        nn.ReLU(),
        nn.Linear(50, 50),
        nn.ReLU(),
        nn.Linear(50, 10),
        nn.Softmax(dim=1),
    )
  def forward(self, x):
    x = self.layers(x)
    return x

In [8]:
model = Net().to(dev)
optim = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
cross_entropy = torch.nn.CrossEntropyLoss().to(dev)

epochs = 3

In [9]:
def test(model, test_loader, cross_entropy):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(dev), target.to(dev)
            output = model(data)
            predict = cross_entropy(output.log(), target)
            test_loss += predict.item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)'.format(
        test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

In [10]:
epochs = 1000
# Lo entrene un ratito y tira a 87% con ruido con desviacion 0.4
for i in range(epochs):
  for batch_idx, (data, label) in enumerate(train_loader):
      data, label = data.to(dev), label.to(dev)

      optim.zero_grad()
        
      predictions = model(data)
      loss = cross_entropy(predictions, label)
        
      loss.backward()
      optim.step()
  test(model, valid_loader, cross_entropy)

Test set: Average loss: 0.0043, Accuracy: 1498/10000 (15%)
Test set: Average loss: 0.0042, Accuracy: 2256/10000 (23%)
Test set: Average loss: 0.0041, Accuracy: 2376/10000 (24%)
Test set: Average loss: 0.0039, Accuracy: 2733/10000 (27%)
Test set: Average loss: 0.0035, Accuracy: 3928/10000 (39%)
Test set: Average loss: 0.0033, Accuracy: 5157/10000 (52%)
Test set: Average loss: 0.0035, Accuracy: 5324/10000 (53%)
Test set: Average loss: 0.0035, Accuracy: 5383/10000 (54%)
Test set: Average loss: 0.0032, Accuracy: 6015/10000 (60%)
Test set: Average loss: 0.0032, Accuracy: 6855/10000 (69%)
Test set: Average loss: 0.0032, Accuracy: 6979/10000 (70%)
Test set: Average loss: 0.0028, Accuracy: 7142/10000 (71%)
Test set: Average loss: 0.0026, Accuracy: 7657/10000 (77%)
Test set: Average loss: 0.0027, Accuracy: 7796/10000 (78%)
Test set: Average loss: 0.0027, Accuracy: 7896/10000 (79%)
Test set: Average loss: 0.0028, Accuracy: 7966/10000 (80%)
Test set: Average loss: 0.0028, Accuracy: 7998/10000 (80

KeyboardInterrupt: 

In [14]:
torch.save(model.state_dict(), './checkpoint.pt')