In [31]:
import torch
import torch.nn as nn
import numpy as np 
import torch.nn.functional as F 

import visdom
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Visdom client used by the training loop to stream its loss curves.
vis = visdom.Visdom()

Setting up a new session...


In [33]:
import os

# Directory that holds the notebook's relative resources (e.g. the "data/" folder).
# Set the ALL_ABOUT_TORCH_DIR environment variable to override it per machine
# instead of editing this cell; the default is the original author's checkout.
PROJECT_DIR = os.environ.get(
    "ALL_ABOUT_TORCH_DIR",
    "C:\\Users\\SAIL\\Documents\\GitHub\\all_about_torch",
)

if os.path.isdir(PROJECT_DIR):
    os.chdir(PROJECT_DIR)
else:
    # Keep the current working directory instead of failing with
    # FileNotFoundError on machines where the default path does not exist.
    print(f"Project directory {PROJECT_DIR!r} not found; staying in {os.getcwd()!r}")

In [None]:
class fxnnxc(nn.Module):
    """Convolutional variational autoencoder (VAE) for single-channel images.

    The encoder stacks three stride-2 convolutions (1 -> 32 -> 64 -> 256
    channels), flattens the result and projects it to the mean and
    log-variance of a 64-dimensional latent Gaussian. The decoder mirrors the
    encoder with transposed convolutions and ends in a ``Tanh``.

    The linear layers assume the encoder output is 4x4 per channel, which is
    what a 32x32 input produces after three halvings.
    """

    def __init__(self):
        super().__init__()
        self.in_channel = 1
        self.latent_dim = 64
        self.hidden_dims = [self.in_channel, 32, 64, 256]
        # 16 = 4 * 4 spatial positions left after three stride-2 convolutions
        # applied to a 32x32 input.
        flat_features = self.hidden_dims[-1] * 16

        # build encoder: each stage halves the spatial size
        modules = []
        for i in range(len(self.hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.Conv2d(self.hidden_dims[i], out_channels=self.hidden_dims[i + 1],
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(self.hidden_dims[i + 1]),
                    nn.LeakyReLU()
                )
            )

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, self.latent_dim)
        self.fc_var = nn.Linear(flat_features, self.latent_dim)

        # build decoder: walk hidden_dims backwards, doubling the spatial size
        # at every stage (256 -> 64 -> 32 channels)
        modules = []
        self.decoder_input = nn.Linear(self.latent_dim, flat_features)
        for i in range(len(self.hidden_dims) - 1, 1, -1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(self.hidden_dims[i],
                                       self.hidden_dims[i - 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(self.hidden_dims[i - 1]),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*modules)
        # Last upsampling step followed by a projection back to the input
        # channel count; Tanh bounds the output to [-1, 1].
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(self.hidden_dims[1],
                               self.hidden_dims[1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(self.hidden_dims[1]),
            nn.LeakyReLU(),
            nn.Conv2d(self.hidden_dims[1], out_channels=self.hidden_dims[0],
                      kernel_size=3, padding=1),
            nn.Tanh())

    def encode(self, input):
        """Encode a batch of images into latent Gaussian parameters.

        Args:
            input: image batch fed to the convolutional encoder.

        Returns:
            ``[mu, log_var]``, each of shape ``(batch, latent_dim)``.
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z):
        """Map latent codes ``z`` of shape ``(batch, latent_dim)`` to images."""
        result = self.decoder_input(z)
        # Undo the encoder's flatten: (batch, C * 16) -> (batch, C, 4, 4),
        # with C taken from hidden_dims rather than hard-coded.
        result = result.view(-1, self.hidden_dims[-1], 4, 4)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input):
        """Run a full encode / sample / decode pass.

        Returns:
            ``[reconstruction, input, mu, log_var]`` - the list expected by
            :meth:`loss_function`.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)

        return [self.decode(z), input, mu, log_var]

    def loss_function(self, args, **kwargs) -> dict:
        """Compute the VAE loss from the list returned by :meth:`forward`.

        Args:
            args: ``[reconstruction, input, mu, log_var]``.
            **kwargs: optional ``kld_weight`` (default 32) scaling the KL term.

        Returns:
            Dict with the total ``"loss"``, the MSE ``"Reconsturciton_loss"``
            and the batch-averaged ``"KLD"``. The misspelled key is kept
            because existing callers look it up by that name.
        """
        recons, target, mu, log_var = args[0], args[1], args[2], args[3]

        kld_weight = kwargs.get("kld_weight", 32)

        recons_loss = F.mse_loss(recons, target)
        # KL(N(mu, sigma^2) || N(0, I)) summed over latent dims, averaged over the batch.
        kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        loss = recons_loss + kld_weight * kld_loss
        return {"loss": loss, "Reconsturciton_loss": recons_loss, "KLD": kld_loss}

    def sample(self, num_samples:int, current_device:int, **kwargs)->torch.Tensor:
        """Decode ``num_samples`` latent vectors drawn from the standard normal prior.

        Args:
            num_samples: number of images to generate.
            current_device: device (index or ``torch.device``) to run on.

        Returns:
            Tensor of generated images, one per latent sample.
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)

        # Decode through the model itself; the previous code called .decode on
        # an unassigned local and always raised.
        samples = self.decode(z)
        return samples

    def generate(self, x:torch.Tensor, **kwargs) ->torch.Tensor:
        """Return only the reconstruction of ``x``."""
        return self.forward(x)[0]


In [30]:
#--------------------------------------------- file read
x_train = np.load("data/mnist_train.npy")
x_test  = np.load("data/mnist_test.npy")
y_train = np.load("data/mnist_train_target.npy")
y_test  = np.load("data/mnist_test_target.npy")
#--------------------------------------------- numpy to tensor
x_train  = torch.from_numpy(x_train).float()       # using long here causes an error when the loss is computed
x_test   = torch.from_numpy(x_test).float()
y_train  = torch.from_numpy(y_train).long()        # using float here causes an error when the loss is computed
y_test   = torch.from_numpy(y_test).long()

#--------------------------------------------- data to dataset
train_dataset = torch.utils.data.TensorDataset(x_train, y_train)
test_dataset  = torch.utils.data.TensorDataset(x_test,  y_test)

#--------------------------------------------- dataset to dataloader
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=32,
                                           shuffle=True,
                                           num_workers=2)

# Evaluation must read the held-out split (this loader previously wrapped
# train_dataset); shuffling is unnecessary when only aggregating metrics.
test_loader = torch.utils.data.DataLoader(test_dataset,
                                          batch_size=32,
                                          shuffle=False,
                                          num_workers=2)
                                    

In [32]:
# Instantiate the VAE and move its parameters to the selected device.
model = fxnnxc()
model = model.to(device)
# Classification criterion; the training loop uses model.loss_function instead.
criterion = nn.CrossEntropyLoss()
# Adam over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)

In [34]:
EPOCH = 3


def _plot_metric(window, title, step, value, epoch):
    """Append one (step, value) point to the Visdom line plot in `window`.

    Each epoch gets its own named trace so the curves of different epochs
    can be compared in the same window.
    """
    vis.line(X=[step], Y=[value],
             win=window,
             update="append",
             name=f"epoch {epoch}",
             opts=dict(showlegend=True, title=title))


# Make sure BatchNorm uses batch statistics even if the model was switched to
# eval mode by a previously executed cell.
model.train()

for t in range(EPOCH): # EPOCH
    for i, (sample, _) in enumerate(train_loader): #BATCH
        # Scale pixel values by 1/255, zero-pad 2 pixels on every side of the
        # last two dims (28x28 MNIST -> 32x32), then add the channel axis.
        sample = sample / 255
        sample = torch.nn.functional.pad(sample, (2, 2, 2, 2), 'constant')
        sample = sample.unsqueeze(dim=1)
        sample = sample.to(device)

        y = model(sample)
        loss = model.loss_function(y)
        optimizer.zero_grad()
        loss['loss'].backward()
        optimizer.step()

        # Console log at batch 99, 1099, 2099, ... of every epoch.
        if i % 1000 == 99:
            print(t, loss['loss'].item())

        #--------------- VISDOM: one point per batch for each loss term
        _plot_metric("loss", "loss", i, loss['loss'].item(), t + 1)
        _plot_metric("reconstruction_loss", "recon", i, loss['Reconsturciton_loss'].item(), t + 1)
        _plot_metric("KLDivergence", "KLD", i, loss['KLD'].item(), t + 1)
        #---------------------------------------------------

0 13.93436336517334


KeyboardInterrupt: 

In [None]:
# Test
# The model is a VAE: its forward pass returns [reconstruction, input, mu, log_var],
# not class scores, so there is no accuracy to compute. Evaluate the loss terms
# on test_loader instead, using the same preprocessing as the training loop.
was_training = model.training
model.eval()  # BatchNorm uses its running statistics during evaluation

total = 0
loss_sums = {"loss": 0.0, "Reconsturciton_loss": 0.0, "KLD": 0.0}
try:
    with torch.no_grad():
        for images, _ in test_loader:
            images = images / 255
            images = torch.nn.functional.pad(images, (2, 2, 2, 2), 'constant')
            images = images.unsqueeze(dim=1).to(device)

            losses = model.loss_function(model(images))

            # Weight every batch by its size so a smaller final batch does
            # not skew the per-image averages.
            batch_size = images.size(0)
            total += batch_size
            for key in loss_sums:
                loss_sums[key] += losses[key].item() * batch_size
finally:
    # Restore whichever mode the model was in before this cell ran.
    model.train(was_training)

# max(total, 1) avoids a ZeroDivisionError when the loader is empty.
test_metrics = {key: value / max(total, 1) for key, value in loss_sums.items()}
print(f"Average losses over {total} test images: "
      f"loss={test_metrics['loss']:.4f}, "
      f"reconstruction={test_metrics['Reconsturciton_loss']:.4f}, "
      f"KLD={test_metrics['KLD']:.4f}")