In [None]:
import torch
from torch.utils import tensorboard
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Is a CUDA-capable GPU usable from this kernel? (displayed as True/False)
bool(torch.cuda.is_available())

In [None]:
import torchvision as tv

# Per-image pipeline: image -> float tensor in [0, 1] -> float32 -> (x - 0.5) / 0.5, i.e. [-1, 1].
my_transforms = tv.transforms.Compose(
    [
        tv.transforms.ToTensor(),
        tv.transforms.ConvertImageDtype(torch.float),
        tv.transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split; downloaded into the shared datasets folder on first run.
mnist_digits = tv.datasets.MNIST(root='../../datasets', download=True, transform=my_transforms)

In [None]:
# Preview one training image. Fetch the (image, label) pair once, because each
# index re-runs the transform pipeline. cmap='gray' draws the single-channel digit
# in grayscale rather than matplotlib's default colour map.
preview_img, preview_label = mnist_digits[1000]
plt.imshow(preview_img.squeeze(), cmap='gray')
plt.title(f'label = {preview_label}');

In [None]:
class Digits(torch.utils.data.Dataset):
    def __init__(self, mnist_digits, digits='all', cache_prefix='mnist'):
        """
        A little wrapper around an MNIST dataset that materialises it into tensors.

        mnist_digits: object returned from torchvision.datasets.MNIST (anything with
            len() whose items are (image, label) pairs, with the image assignable into
            a 1x28x28 tensor slot).
        digits: 'all' or a list of digits [1, 2, 3, ...] when interested in a subset of MNIST data
        cache_prefix: path prefix of the tensor cache. Files are
            '<cache_prefix>_samples.pt' and '<cache_prefix>_labels.pt'; the default
            gives 'mnist_samples.pt' / 'mnist_labels.pt' in the working directory.
        returns:
            when used with dataloaders, returns items of the form {'sample': 1x28x28, 'label': digit}

        The full (unfiltered) tensors are cached on first use. Later constructions
        reload them when the cached shapes match len(mnist_digits).
        NOTE: the cache is validated by shape only, not content. Delete the files if
        the source dataset or its transforms change.
        """
        n = len(mnist_digits)
        samples_path = f'{cache_prefix}_samples.pt'
        labels_path = f'{cache_prefix}_labels.pt'

        cached = self._load_cache(samples_path, labels_path, n)
        if cached is not None:
            self.samples, self.labels = cached
        else:
            self.samples = torch.empty(n, 1, 28, 28)
            # Labels are stored in torch's default float dtype; callers cast when needed.
            self.labels = torch.empty(n, 1)
            for i in range(n):
                # Index once per item: every access re-runs the dataset's transform pipeline.
                image, label = mnist_digits[i]
                self.samples[i, ...] = image
                self.labels[i, ...] = label
            torch.save(self.samples, samples_path)
            torch.save(self.labels, labels_path)

        if digits != 'all':
            # Boolean mask of the rows whose label is any of the requested digits.
            keep = torch.zeros(self.labels.shape[0], dtype=torch.bool)
            for d in digits:
                keep |= self.labels[:, 0] == d
            self.labels = self.labels[keep]
            self.samples = self.samples[keep]

        # Number of items after filtering (now defined for every `digits` value).
        self.n = self.samples.shape[0]

    @staticmethod
    def _load_cache(samples_path, labels_path, n):
        """Return cached (samples, labels) tensors, or None if missing, unreadable or mis-shaped.

        torch.load unpickles its input, so only point the cache at files you trust.
        """
        try:
            samples = torch.load(samples_path)
            labels = torch.load(labels_path)
            shapes_ok = (tuple(samples.shape) == (n, 1, 28, 28)
                         and tuple(labels.shape) == (n, 1))
        except Exception:  # missing/corrupt file or non-tensor payload: rebuild instead
            return None
        return (samples, labels) if shapes_ok else None

    def __len__(self):
        return self.samples.shape[0]

    def __getitem__(self, idx):
        return {'sample': self.samples[idx], 'label': self.labels[idx]}

In [None]:
# Keep only the images of the digit 5, then show how many there are.
dataset = Digits(mnist_digits=mnist_digits, digits=[5])
len(dataset)

In [None]:
# Show a randomly chosen sample from the filtered dataset, titled with its label.
# The item is fetched once and the float label tensor is converted straight to an int.
i = np.random.choice(len(dataset))
picked = dataset[i]
plt.imshow(picked['sample'].squeeze(), cmap='gray')
plt.title(str(int(picked['label'].item())));

In [None]:
batch_size = 8
# Reshuffle every pass so batches differ from epoch to epoch.
dataloader = torch.utils.data.DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Peek at the first batch: how many samples and how many labels it holds.
for a_batch_of_data in dataloader:
    print(len(a_batch_of_data['sample']), len(a_batch_of_data['label']), sep='\n')
    break

In [None]:
class AutoEncoder(torch.nn.Module):
    """Small convolutional autoencoder for single-channel 28x28 images.

    The encoder compresses a 1x28x28 image to an 8x2x2 code. The decoder maps it
    back to 1x28x28 and ends in Tanh, so outputs lie in (-1, 1).
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn

        # 1x28x28 -> 16x10x10 -> 16x5x5 -> 8x3x3 -> 8x2x2
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 8, kernel_size=3, stride=2, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=1),
        )

        # 8x2x2 -> 16x5x5 -> 8x15x15 -> 1x28x28
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, kernel_size=3, stride=2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(16, 8, kernel_size=5, stride=3, padding=1),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(8, 1, kernel_size=2, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        """Encode then decode ``x``; the output has the same spatial size as a 1x28x28 input."""
        code = self.encoder(x)
        return self.decoder(code)

In [None]:
# Fresh, untrained instance used for a CPU shape smoke test.
ae: AutoEncoder = AutoEncoder()

In [None]:
# Smoke test: one forward pass on a single batch. no_grad skips building an
# autograd graph for this throwaway output, and the assert checks that the decoder
# restores the input shape.
with torch.no_grad():
    for data in dataloader:
        reconstruction = ae(data['sample'])
        assert reconstruction.shape == data['sample'].shape, reconstruction.shape
        break

In [None]:
class cudafy:
    """Helper that moves tensors and ``nn.Module`` objects onto a CUDA device.

    When CUDA is unavailable, ``put``/``__call__`` return their input unchanged and
    ``name`` returns an explanatory string.

    device: CUDA device to target (e.g. ``1`` or ``'cuda:1'``). A falsy value, or a
        runtime without CUDA, falls back to device ``0``.
    """

    def __init__(self, device=None):
        # Falsy `device` (None, 0, '') or no CUDA -> device 0.
        self.device = device if (torch.cuda.is_available() and device) else 0

    def name(self):
        """Return the CUDA device name, or a message when CUDA is unavailable."""
        if torch.cuda.is_available():
            return torch.cuda.get_device_name(self.device)
        return 'Cuda is not available.'

    def put(self, x):
        """Move tensor or module ``x`` to ``self.device`` when CUDA is available.

        Tensors already on a CUDA device are returned unchanged. ``nn.Module`` has no
        ``is_cuda`` attribute, so modules are always moved with ``.to`` (which returns
        the module itself).
        """
        if not torch.cuda.is_available():
            return x
        if isinstance(x, torch.Tensor) and x.is_cuda:
            return x
        return x.to(device=self.device)

    def __call__(self, x):
        return self.put(x)

    def get(self, x):
        """Move tensor or module ``x`` to the CPU; CPU tensors are returned unchanged."""
        if isinstance(x, torch.nn.Module) or x.is_cuda:
            return x.to(device='cpu')
        return x
    
def cpu(x):
    """Return tensor or ``nn.Module`` ``x`` on the CPU.

    CPU tensors are returned unchanged. ``nn.Module`` has no ``is_cuda`` attribute,
    so modules are always moved with ``.to`` (which returns the module itself).
    """
    if isinstance(x, torch.nn.Module) or x.is_cuda:
        return x.to(device='cpu')
    return x

In [None]:
# Print (rather than display) whether CUDA is usable.
print(str(torch.cuda.is_available()))


In [None]:
gpu = cudafy()
# Build the model and hand it to the device helper (a no-op without CUDA).
model = gpu.put(AutoEncoder())

In [None]:
# Smoke test on the target device: one forward pass on a single batch, without an
# autograd graph, checking that the reconstruction matches the input shape.
with torch.no_grad():
    for data in dataloader:
        batch = gpu(data['sample'])
        reconstruction = model(batch)
        assert reconstruction.shape == batch.shape, reconstruction.shape
        break

In [None]:
# for i in model.parameters():
#     print(i)

In [None]:
# Optimisation hyper-parameters.
learning_rate = 1e-2
weight_decay = 1e-5  # Adam's L2 penalty

# Reconstruction loss: mean squared error between output and input images.
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(
    model.parameters(), weight_decay=weight_decay, lr=learning_rate
)

In [None]:
# Smoke test of the training wiring: 10 "epochs" of ONE batch each (see the
# `break`), printing the loss of every step.
losses = []

for epoch in range(10):
    print(f'Epoch = {epoch}')
    for data in dataloader:
        imgs = gpu(data['sample'])

        output = model(imgs)
        loss = criterion(output, imgs)

        # Store a plain float. The tensor carries its autograd history (grad_fn) and
        # device placement, which keeps those objects alive and makes e.g.
        # plt.plot(losses) fail on a tensor that requires grad.
        losses.append(loss.item())
        print(f'  loss = {losses[-1]:.6f}')

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        break  # first batch only; remove for a full pass over the data

In [None]:
# TensorBoard writer; log_dir=None lets SummaryWriter pick its default run directory.
writer = tensorboard.SummaryWriter(log_dir=None)

In [None]:
def train_model(iter):
    """Train the notebook's global ``model`` for ``iter`` epochs, logging to TensorBoard.

    Relies on the notebook-level names ``model``, ``dataloader``, ``criterion``,
    ``optimizer``, ``gpu``, ``tv`` and ``writer``.

    Logged data:
      * ``Loss/train``: every batch's reconstruction loss, against a running batch
        counter so successive batches get distinct x-coordinates.
      * ``mnist`` and ``mnist/reconstruction``: image grids of the first batch of
        each epoch (inputs and model outputs), at step ``epoch``.

    iter: number of epochs. The name is kept for compatibility with existing callers;
        it shadows the builtin ``iter`` only inside this function.
    """
    global_step = 0
    for epoch in range(iter):
        for batch_idx, data in enumerate(dataloader):
            imgs = gpu(data['sample'])
            output = model(imgs)
            loss = criterion(output, imgs)

            # Use a running batch counter, not `epoch`: logging every batch of an
            # epoch at the same step stacks them onto one x-coordinate.
            writer.add_scalar("Loss/train", loss.item(), global_step)

            if batch_idx == 0:
                # Assumes images in [-1, 1] (Normalize((0.5,), (0.5,)) inputs, Tanh
                # outputs). Map them to [0, 1] so add_image does not clip the negative half.
                writer.add_image("mnist", tv.utils.make_grid((imgs + 1) / 2), epoch)
                writer.add_image("mnist/reconstruction",
                                 tv.utils.make_grid((output.detach() + 1) / 2), epoch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_step += 1

    # Push buffered events to disk so TensorBoard sees the finished run immediately.
    writer.flush()

In [None]:
# Full training run: 10 epochs over the filtered dataset.
train_model(iter=10)

In [None]:
import torch
from torch.utils import tensorboard

# Close any writer created earlier before rebinding the name, so its buffered
# events are flushed to disk and its event-file handle is released.
if 'writer' in globals():
    writer.close()
writer = tensorboard.SummaryWriter()

In [None]:
# NOTE(review): stray scratch cell. It only evaluates and displays a tuple literal,
# and the value is not stored anywhere. TODO: remove.
0,1,2,3,4,5,6,7,8,9,8,7,6,5,4,3,2,1,0,1,2,3,4,5,6,7

In [None]:
# Re-check GPU availability (displayed as True/False).
bool(torch.cuda.is_available())