Init

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from torchvision import datasets, transforms
import time
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

cuda


Model Definition

In [2]:
class CNN(nn.Module):
    def __init__(self, inChannels=3, chunkSize = 32):
        super(CNN, self).__init__()
        self.chunkSize = chunkSize
        self.encoder = nn.Sequential(            
            nn.Linear(in_features=3*chunkSize*chunkSize, out_features=chunkSize*chunkSize),
            nn.ReLU(),                        
        )                
        self.flatten = nn.Flatten()  
        self.unflatten  = nn.Unflatten(1, (3,chunkSize,chunkSize))

        self.decoder = nn.Sequential(
            nn.Linear(in_features=chunkSize*chunkSize, out_features=3*chunkSize*chunkSize),
            nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.flatten(x)
        x = self.encoder(x)
        out = self.decoder(x)
        out = self.unflatten(out)
        return out

model = CNN().to(device)

In [46]:
# Create Data Set
trainImage = Image.open("./test.png")
chunkSize = model.chunkSize
trainImageTensor = torch.tensor(np.array(trainImage)).to(device)
#Create Image Chunks

#for i in range(0, trainImageTensor.shape[0], chunkSize):
#    for j in range(0, trainImageTensor.shape[1], chunkSize):
#        chunk = trainImageTensor[i:i+chunkSize, j:j+chunkSize]
#        chunk = torch.transpose(chunk, 0, 2)/255.0
#        chunks.append(chunk)

print(trainImageTensor.shape)
xRemain = trainImageTensor.shape[0] % chunkSize
yRemain = trainImageTensor.shape[1] % chunkSize
trainImageTensor = trainImageTensor[0:trainImageTensor.shape[0]-xRemain, 0:trainImageTensor.shape[1]-yRemain]
print(trainImageTensor.shape)
#calculate count of chunks
yChunks = int((trainImageTensor.shape[0])/chunkSize)
xChunks = int((trainImageTensor.shape[1])/chunkSize)
chunks = xChunks * yChunks
print("xChunks: " + str(xChunks))
print("yChunks: " + str(yChunks))
# Change the shape of the tensor to be (chunks,3, 32, 32)
trainImageTensor = torch.transpose(trainImageTensor, 2, 0).transpose(2 ,1)/255.0
print("trainImageTensor = ",trainImageTensor.shape)
#trainImageChunks = trainImageTensor.reshape(chunks,3,chunkSize,chunkSize)
stride4 = 1
stride3 = trainImageTensor.shape[2]
stride2 = chunkSize
stride1 = chunkSize * trainImageTensor.shape[1]
stride0 = trainImageTensor.shape[1] * trainImageTensor.shape[2]
trainImageChunks = trainImageTensor.as_strided((yChunks, xChunks,3,chunkSize, chunkSize), (stride1, stride2,stride0, stride3, stride4)).flatten(0,1)
print(trainImageChunks.shape)
chunk = trainImageChunks[100].cpu().transpose(0,2)
print(chunk.shape)
img = Image.fromarray(np.uint8(chunk*255.0)).save("test-out.png")


torch.Size([4745, 6400, 3])
torch.Size([4736, 6400, 3])
xChunks: 200
yChunks: 148
trainImageTensor =  torch.Size([3, 4736, 6400])
torch.Size([29600, 3, 32, 32])
torch.Size([32, 32, 3])


In [50]:

optimizer = optim.Adam(model.parameters(), lr=0.0005)
writer = SummaryWriter()
# create the loss function
criterion = nn.MSELoss()
# train the model
batchSize = 100
epochs = 100
batchCount = int(chunks/batchSize)
for epoch in range(1, epochs+1):
    for batch in range(0, batchCount):
        batchStart = batch * batchSize
        batchEnd = batchStart + batchSize
        if(batchEnd > chunks):
            batchEnd = chunks
        batchTensor = trainImageChunks[batchStart:batchEnd]        
        optimizer.zero_grad()
        output = model(batchTensor)
        loss = criterion(output, batchTensor)
        loss.backward()
        optimizer.step()
        #print("Batch: " + str(batch) + " Loss: " + str(loss.item()))    
    print("Epoch: " + str(epoch) + " Loss: " + str(loss.item()))
    inputExamples = batchTensor[0:10,:,:,:]
    print(inputExamples.shape)
    outputExamples = output[0:10,:,:,:]
    inputGrid = torchvision.utils.make_grid(inputExamples)
    outputGrid = torchvision.utils.make_grid(outputExamples)
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_image('input', inputGrid, epoch)
    writer.add_image('output', outputGrid, epoch)
    writer.flush()

# save the model
torch.save(model.state_dict(), "./model1.pth")


Epoch: 1 Loss: 0.027219856157898903
torch.Size([10, 3, 32, 32])
Epoch: 2 Loss: 0.027228614315390587
torch.Size([10, 3, 32, 32])
Epoch: 3 Loss: 0.027233509346842766
torch.Size([10, 3, 32, 32])
Epoch: 4 Loss: 0.02723878249526024
torch.Size([10, 3, 32, 32])
Epoch: 5 Loss: 0.027242975309491158
torch.Size([10, 3, 32, 32])
Epoch: 6 Loss: 0.02729019522666931
torch.Size([10, 3, 32, 32])
Epoch: 7 Loss: 0.027253275737166405
torch.Size([10, 3, 32, 32])
Epoch: 8 Loss: 0.027412772178649902
torch.Size([10, 3, 32, 32])
Epoch: 9 Loss: 0.027245713397860527
torch.Size([10, 3, 32, 32])
Epoch: 10 Loss: 0.027252744883298874
torch.Size([10, 3, 32, 32])
Epoch: 11 Loss: 0.027369966730475426
torch.Size([10, 3, 32, 32])
Epoch: 12 Loss: 0.027417905628681183
torch.Size([10, 3, 32, 32])
Epoch: 13 Loss: 0.027245497331023216
torch.Size([10, 3, 32, 32])
Epoch: 14 Loss: 0.027295926585793495
torch.Size([10, 3, 32, 32])
Epoch: 15 Loss: 0.027382301166653633
torch.Size([10, 3, 32, 32])
Epoch: 16 Loss: 0.02723708562552929
