In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageFile
import requests
from io import BytesIO
from IPython.display import display
import matplotlib.pyplot as plt
import torchvision.models as models
import time
import torchinfo
from torchvision.io import read_image
from torch.utils.data import Dataset
import time
import datetime

from pyutils.lazynoisedataset import LazyNoiseDataset
from pyutils.distdataset import DistDataset


from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader

# Select device: use CUDA when a GPU is visible, otherwise stay on the CPU.
# (An XPU branch -- torch.xpu.is_available() -> torch.device("xpu") -- is
# intentionally left disabled.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device: {}".format(device))


# utils
def count_parameters(model):
    """Return the total number of trainable scalar parameters in ``model``.

    Only parameters with ``requires_grad`` set are counted; frozen
    parameters are skipped.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

Device: cpu


  from tqdm.autonotebook import tqdm


In [2]:
# my utils
from pyutils.lazynoisedataset import LazyNoiseDataset
from pyutils.distdataset import DistDataset, SplitDataset

from torch.utils.data import DataLoader

# loading dataset: sharded CXR tensors wrapped by the lazy noise dataset
cxr_images = DistDataset("cheXpert/cxp_cxrs{:03}.pt")
noise_dataset = LazyNoiseDataset(cxr_images)
# Alternative source, kept for reference:
#noise_dataset = LazyNoiseDataset.from_distdataset_pickle("pyutils/pickles/cxrdataset.pickle")

# 80/20 train/test split. The boundary is a float and is handed to
# SplitDataset as-is.
# NOTE(review): presumably SplitDataset truncates it to an index -- confirm.
split_point = len(noise_dataset) * 0.8
train_noise_dataset = SplitDataset(noise_dataset, split_end=split_point)
test_noise_dataset = SplitDataset(noise_dataset, split_start=split_point)

# Both loaders share the same batching configuration.
loader_options = dict(batch_size=32, num_workers=2)
train_noise_dataloader = DataLoader(train_noise_dataset, **loader_options)
test_noise_dataloader = DataLoader(test_noise_dataset, **loader_options)


7727 datapoints from  cheXpert/cxp_cxrs000.pt
7660 datapoints from  cheXpert/cxp_cxrs001.pt
7753 datapoints from  cheXpert/cxp_cxrs002.pt
7617 datapoints from  cheXpert/cxp_cxrs003.pt
7728 datapoints from  cheXpert/cxp_cxrs004.pt
7657 datapoints from  cheXpert/cxp_cxrs005.pt
7756 datapoints from  cheXpert/cxp_cxrs006.pt
7712 datapoints from  cheXpert/cxp_cxrs007.pt
7669 datapoints from  cheXpert/cxp_cxrs008.pt
7626 datapoints from  cheXpert/cxp_cxrs009.pt
7666 datapoints from  cheXpert/cxp_cxrs010.pt
7709 datapoints from  cheXpert/cxp_cxrs011.pt
7754 datapoints from  cheXpert/cxp_cxrs012.pt
7678 datapoints from  cheXpert/cxp_cxrs013.pt
8039 datapoints from  cheXpert/cxp_cxrs014.pt
9512 datapoints from  cheXpert/cxp_cxrs015.pt
9506 datapoints from  cheXpert/cxp_cxrs016.pt
9508 datapoints from  cheXpert/cxp_cxrs017.pt
9430 datapoints from  cheXpert/cxp_cxrs018.pt
9400 datapoints from  cheXpert/cxp_cxrs019.pt
9376 datapoints from  cheXpert/cxp_cxrs020.pt
9314 datapoints from  cheXpert/cxp

In [11]:
# Define the autoencoder architecture
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for single-channel images.

    The encoder applies two ``3x3 conv -> ReLU -> 2x2 max-pool`` stages
    (1 -> 16 -> 32 channels), shrinking each spatial dimension by a factor
    of 4. The decoder mirrors it with two stride-2 transposed convolutions
    (32 -> 16 -> 1 channels) followed by a sigmoid, so the output has the
    same shape as the input whenever height and width are divisible by 4.

    Args:
        device: Device to place the weights on. ``None`` (the default)
            selects CUDA when available and the CPU otherwise, so that
            ``Autoencoder()`` does not depend on any notebook-level global.
    """

    def __init__(self, device=None):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(32, 16,
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1,
                               kernel_size=3, # unet uses 2?
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.Sigmoid()
        )

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.to(device)

    def _to_model_device(self, x):
        """Move ``x`` to wherever the weights currently live."""
        # Following the parameters (instead of a fixed global) keeps forward()
        # correct even after a later ``model.to(other_device)``.
        return x.to(next(self.parameters()).device)

    def forward(self, x, residual=False):
        """Run the autoencoder.

        Args:
            x: Input tensor; it is moved to the model's device first.
            residual: When true, return ``x + decoder(encoder(x))`` instead
                of the plain reconstruction.

        Returns:
            A tensor on the model's device.
        """
        if residual:
            return self.forward_residual(x)
        return self.forward_non_residual(x)

    def forward_residual(self, x):
        """Return the input plus the network's reconstruction of it."""
        x = self._to_model_device(x)
        # No clone needed: none of the layers operate in place, so ``x`` is
        # left untouched by the non-residual pass.
        return x + self.forward_non_residual(x)

    def forward_non_residual(self, x):
        """Return ``decoder(encoder(x))``, the plain reconstruction."""
        x = self._to_model_device(x)
        return self.decoder(self.encoder(x))


# Build the model on the selected device, then display a per-layer summary
# using the shape of the first sample's input tensor.
m = Autoencoder().to(device)
sample_shape = noise_dataset[0][0].shape
torchinfo.summary(m, sample_shape)



idx=0: loafing 0th datapoint from cheXpert/cxp_cxrs000.pt
torch.Size([1, 320, 320])
idx=0: loafing 0th datapoint from cheXpert/cxp_cxrs000.pt


Layer (type:depth-idx)                   Output Shape              Param #
Autoencoder                              [1, 320, 320]             --
├─Sequential: 1-1                        [32, 80, 80]              --
│    └─Conv2d: 2-1                       [16, 320, 320]            160
│    └─ReLU: 2-2                         [16, 320, 320]            --
│    └─MaxPool2d: 2-3                    [16, 160, 160]            --
│    └─Conv2d: 2-4                       [32, 160, 160]            4,640
│    └─ReLU: 2-5                         [32, 160, 160]            --
│    └─MaxPool2d: 2-6                    [32, 80, 80]              --
├─Sequential: 1-2                        [1, 320, 320]             --
│    └─ConvTranspose2d: 2-7              [16, 160, 160]            4,624
│    └─ReLU: 2-8                         [16, 160, 160]            --
│    └─ConvTranspose2d: 2-9              [1, 320, 320]             145
│    └─Sigmoid: 2-10                     [1, 320, 320]             --
Total p

In [15]:
# Options
NUM_EPOCHS = 10
NUM_IMG_EXAMPLES = 5        # rows in the noisy / real / restored preview figure
NUM_TESTING_EPOCHS = 10     # in addition to last

# Time and logging
start_time = time.time()
last_epoch_time = time.time()  # to be updated

# Loss histories; testing_epoch_is records which epochs were evaluated.
training_loss_hist, testing_loss_hist, testing_epoch_is = [], [], []

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(m.parameters(), lr=1e-3)

# Train the autoencoder

# Evaluate roughly NUM_TESTING_EPOCHS times over the run, plus always on the
# final epoch. An integer interval keeps the modulo test exact; the previous
# float division only matched when NUM_EPOCHS was a multiple of
# NUM_TESTING_EPOCHS.
log_every = max(1, NUM_EPOCHS // NUM_TESTING_EPOCHS)

for epoch in range(NUM_EPOCHS):
    # Restart the clock at the top of the epoch so "this epoch" measures only
    # this epoch's training pass, not the previous epoch's testing/plotting.
    last_epoch_time = time.time()
    m.train()

    progress_log = epoch % log_every == 0 or epoch == NUM_EPOCHS - 1

    train_loss_acc = 0.0

    for x_batch, y_batch in train_noise_dataloader:
        optimizer.zero_grad()
        output = m(x_batch.to(device))
        # MSELoss signature is (input, target).
        loss = criterion(output, y_batch.to(device))
        loss.backward()
        optimizer.step()
        train_loss_acc += loss.item()

    train_loss = train_loss_acc / len(train_noise_dataloader)
    training_loss_hist.append(train_loss)

    epoch_time = time.time() - last_epoch_time

    if progress_log:
        m.eval()

        # Inference only: skip autograd bookkeeping for the test pass and the
        # preview images.
        with torch.no_grad():
            test_loss_acc = 0.0
            for x_test_batch, y_test_batch in test_noise_dataloader:
                output = m(x_test_batch.to(device))
                loss = criterion(output, y_test_batch.to(device))
                test_loss_acc += loss.item()

            test_loss = test_loss_acc / len(test_noise_dataloader)
            # Record epoch and loss together so the two histories cannot get
            # out of step if the run is interrupted mid-evaluation.
            testing_epoch_is.append(epoch)
            testing_loss_hist.append(test_loss)

            # squeeze=False keeps axs two-dimensional even for a single row.
            fig, axs = plt.subplots(NUM_IMG_EXAMPLES, 3, figsize=(8, 8), squeeze=False)

            axs[0,0].set_title("noisy")
            axs[0,1].set_title("real")
            axs[0,2].set_title("restored")

            for i in range(NUM_IMG_EXAMPLES):
                x, y = test_noise_dataloader.dataset[i]
                o = m(x.to(device))

                for ax, img in zip(axs[i], (x, y, o)):
                    ax.set_axis_off()
                    # Drop the leading channel axis rather than reshaping to a
                    # hard-coded 28x28, which fails for any other image size.
                    ax.imshow(img.detach().cpu().squeeze(0).numpy(), cmap='gray')

        plt.show()
        print(
"""
Epoch [{}/{}]
Loss - Train: {:.4f}   Test: {:.4f}
Time - since start: {:.1f}   this epoch: {:.1f}
(note: 'since start' does include total testing...)
""".format(epoch+1, NUM_EPOCHS, train_loss, test_loss, time.time() - start_time, epoch_time))


# Plot train/test loss against 1-based epoch numbers. The x-axes are derived
# from the recorded histories (not NUM_EPOCHS) so the plot still works when
# training was interrupted before the final epoch.
train_epochs = range(1, len(training_loss_hist) + 1)
test_epochs = [i + 1 for i in testing_epoch_is[:len(testing_loss_hist)]]

plt.plot(train_epochs, training_loss_hist, "b:o", label="Train Loss")
plt.plot(test_epochs, testing_loss_hist, "r:o", label="Test Loss")
# Call plt.title; assigning to it would replace the pyplot function with a
# string and leave the figure untitled.
plt.title("Loss over time")
plt.xlabel("Epoch")
plt.ylabel("MSE loss")
plt.legend()
plt.show()


# Persist the trained weights twice: once under a fixed "latest" name and once
# in the archive folder, tagged with a day/time stamp and the parameter count.
weights = m.state_dict()
torch.save(weights, "models/last_model")

stamp = datetime.datetime.now().strftime("%d%b-%H.%M")
archive_path = "models/all/[{}]-{}params".format(stamp, count_parameters(m))
torch.save(weights, archive_path)






idx=0: loafing 0th datapoint from cheXpert/cxp_cxrs000.ptidx=32: loafing 32th datapoint from cheXpert/cxp_cxrs000.pt

idx=1: loafing 1th datapoint from cheXpert/cxp_cxrs000.pt
idx=33: loafing 33th datapoint from cheXpert/cxp_cxrs000.pt
idx=2: loafing 2th datapoint from cheXpert/cxp_cxrs000.pt
idx=34: loafing 34th datapoint from cheXpert/cxp_cxrs000.pt
idx=35: loafing 35th datapoint from cheXpert/cxp_cxrs000.pt
idx=3: loafing 3th datapoint from cheXpert/cxp_cxrs000.pt
idx=36: loafing 36th datapoint from cheXpert/cxp_cxrs000.pt
idx=4: loafing 4th datapoint from cheXpert/cxp_cxrs000.pt
idx=5: loafing 5th datapoint from cheXpert/cxp_cxrs000.ptidx=37: loafing 37th datapoint from cheXpert/cxp_cxrs000.pt

idx=6: loafing 6th datapoint from cheXpert/cxp_cxrs000.pt
idx=38: loafing 38th datapoint from cheXpert/cxp_cxrs000.pt
idx=7: loafing 7th datapoint from cheXpert/cxp_cxrs000.ptidx=39: loafing 39th datapoint from cheXpert/cxp_cxrs000.pt

idx=40: loafing 40th datapoint from cheXpert/cxp_cxrs000

KeyboardInterrupt: 