In [1]:
import torch
import torch.nn as nn

In [2]:
from torchvision.datasets import MNIST
import torchvision.transforms.v2 as v2

my_transform = v2.Compose([
    #v2.ToTensor()
    v2.ToImage(), 
    v2.ToDtype(torch.float32, scale=True)
])
dataset = MNIST('data', download=True, transform=my_transform)

In [3]:
from torch.utils.data import DataLoader

data_loader = DataLoader(dataset, batch_size=64, shuffle=True)
for X_train, y_label in data_loader:
    print(X_train.shape, y_label.shape)
    break

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [4]:

class Encoder(nn.Module):
    def __init__(self, encoder_dim):
        super(Encoder, self).__init__()
        self.encoder_cnn = nn.Sequential(
            nn.Conv2d(1, 32, 3, padding='same'), # 28
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding='same'), # 14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding='same'), # 7
            #nn.BatchNorm2d(128),
            nn.ReLU(),
            #nn.MaxPool2d(2),
            #nn.Conv2d(128, 256, 3, padding='same'),
            #nn.BatchNorm2d(256),
            #nn.ReLU(),
            #nn.MaxPool2d(2),
            #nn.Conv2d(256, 512, 3, padding='same'),
            #nn.BatchNorm2d(512),
            #nn.ReLU(),
            #nn.MaxPool2d(2)
        )
        self.encoder_fc = nn.Sequential(
            nn.Linear(128*7*7, encoder_dim),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.encoder_cnn(x)
        #print(x.shape)
        x = torch.flatten(x, start_dim=1)
        x = self.encoder_fc(x)
        return x

In [126]:
#image = torch.rand(1, 1, 224, 224)
image = torch.rand(1, 1, 28, 28)
encoder = Encoder(128)
encoder(image).shape

torch.Size([1, 128])

In [127]:
class Decoder(nn.Module):
    def __init__(self, encoder_dim):
        super(Decoder, self).__init__()
        self.decoder_input = nn.Sequential(
            nn.Linear(encoder_dim, 128*7*7),
            nn.ReLU()
        )
        self.decoder_cnn = nn.Sequential(
            #nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
            #nn.BatchNorm2d(256),
            #nn.ReLU(),
            #nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            #nn.BatchNorm2d(128),
            #nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1), # 14
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1), # 28
            #nn.BatchNorm2d(32),
            nn.ReLU(),
            #nn.ConvTranspose2d(32, 1, 3, stride=2, padding=1, output_padding=1),
            #nn.ReLU(),            
        )
        
    def forward(self, x):
        x = self.decoder_input(x)
        x = torch.unflatten(x, 1, (128, 7, 7))
        #print(x.shape)
        x = self.decoder_cnn(x)
        return x
        

In [128]:
encoded_image = torch.rand(1, 128)
decoder = Decoder(128)
decoder (encoded_image).shape

torch.Size([1, 32, 28, 28])

In [129]:
for X_train, y_label in data_loader:
    print(X_train.shape, y_label.shape)
    break

torch.Size([64, 1, 28, 28]) torch.Size([64])


In [137]:
import torch.optim as optim
optimizer = optim.Adam([{'params': encoder.parameters()},
                        {'params': decoder.parameters()}],
                        lr=0.0001)
loss_fn = nn.MSELoss()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder = encoder.to(device)
decoder = decoder.to(device)

for X_train, y_label in data_loader:
    optimizer.zero_grad()
    #print(X_train.shape)
    images = X_train.to(device)
    encoded_outputs = encoder(images)
    decoded_outputs = decoder(encoded_outputs)
    loss = loss_fn(decoded_outputs, images)
    loss.backward()
    optimizer.step()

  return F.mse_loss(input, target, reduction=self.reduction)


In [5]:
dataset_train = MNIST('data', download=True, train=True, transform=my_transform)
dataset_test = MNIST('data', download=True, train=False, transform=my_transform)

In [6]:
import torch
from torchvision.datasets import MNIST
import torchvision.transforms as v2

def make_dataset_noisy(train, noise_factor=0.2):
    dataset = MNIST('data', download=True, train=train, transform=my_transform)

    noise = noise_factor * torch.randint(0, 256, dataset.data.shape, dtype=torch.float32)
    noisy_data = dataset.data.float() + noise
    noisy_data = torch.clamp(noisy_data, 0.0, 255.0).type(torch.uint8)
    dataset.data = noisy_data

    return dataset

In [7]:
dataset_noisy_train = make_dataset_noisy(train=True)
dataset_train[0][0].shape, dataset_noisy_train[0][0].shape

(torch.Size([1, 28, 28]), torch.Size([1, 28, 28]))

In [8]:
import matplotlib.pyplot as plt

# def visualize_mnist(img_origin, img_noisy):
#     for i in range(4):
#       print(img_origin[i][0].shape, img_noisy[i][0].shape)
#       plt.subplot(2, 4, i+1)
#       plt.imshow(img_origin[i][0].squeeze())
#       plt.subplot(2, 4, i+5)
#       plt.imshow(img_noisy[i][0].squeeze())

In [10]:
#visualize_mnist(dataset_train, dataset_noisy_train)

In [None]:
plt.imshow(dataset_train[0][0].squeeze())

<matplotlib.image.AxesImage at 0x1aad2e33650>