Init

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from torchvision import datasets, transforms
import time
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

cuda


Model Definition

In [3]:
class CNN(nn.Module):
    def __init__(self, inChannels=3, chunkSizeX = 32,chunkSizeY = 32):
        super(CNN, self).__init__()
        self.chunkSizeX = chunkSizeX
        self.chunkSizeY = chunkSizeY
        self.intermediateSize = 3 * chunkSizeX *chunkSizeY
        self.transfer = nn.Sigmoid()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=3*chunkSizeX*chunkSizeY, out_features=self.intermediateSize),
            nn.ReLU(),
        )                
        self.flatten = nn.Flatten()  
        self.unflatten  = nn.Unflatten(1, (3,chunkSizeY,chunkSizeX))

        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.intermediateSize, out_features=3*chunkSizeX*chunkSizeY),
            self.transfer,
        )
    
    def forward(self, x):
        out = self.flatten(x)
        out = self.encoder(out)
        out = self.decoder(out)
        out = self.unflatten(out)
        return out

model = CNN().to(device)

In [4]:
# Create Data Set
trainImage = Image.open("./test-cropped.png")
chunkSizeX = model.chunkSizeX
chunkSizeY = model.chunkSizeY
trainImageTensor = torch.tensor(np.array(trainImage)).to(device)/255.
print("Image Tensor Shape: " + str(trainImageTensor.shape))
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("Image size: " + str(width) + "x" + str(height))
xRemain = width % chunkSizeX
yRemain = height % chunkSizeY
print("Remains: " + str(xRemain) + " - " + str(yRemain))
trainImageTensor = trainImageTensor[0:height-yRemain, 0:width-xRemain,:]
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("trainImageTensor shape after crop= ",trainImageTensor.shape)
print("New Image size: " + str(width) + "x" + str(height))

xChunks = int((width)/chunkSizeX)
yChunks = int((height)/chunkSizeY)
chunks = xChunks * yChunks
print("xChunks: " + str(xChunks))
print("yChunks: " + str(yChunks))
print("Chunks: " + str(chunks))
stride4 = 1
stride3 = 3
stride2 = width*3
stride1 = chunkSizeX
stride0 = chunkSizeY*width*3
trainImageChunks = trainImageTensor.as_strided((yChunks, xChunks,chunkSizeY, chunkSizeX,3), (stride0,stride1, stride2, stride3, stride4)).flatten(0,1)
print(trainImageChunks.shape)
print("chunk shape = ",trainImageChunks[0].shape)

for i in range(0, 32):
    chunk = trainImageChunks[8192-32+i].cpu()
    img = Image.fromarray(np.uint8(chunk*255)).save("./images/test-out_"+str(i)+".png")


Image Tensor Shape: torch.Size([2048, 4096, 3])
Image size: 4096x2048
Remains: 0 - 0
trainImageTensor shape after crop=  torch.Size([2048, 4096, 3])
New Image size: 4096x2048
xChunks: 128
yChunks: 64
Chunks: 8192
torch.Size([8192, 32, 32, 3])
chunk shape =  torch.Size([32, 32, 3])


In [21]:
#model.load_state_dict(torch.load("./model1.pth"))
#model.eval()

#optimizer = optim.Adam(model.parameters(), lr=0.005)
optimizer = optim.SGD(model.parameters(), lr=10, momentum=0.1)
writer = SummaryWriter()
# create the loss function
criterion = nn.MSELoss()
#criterion = nn.CrossEntropyLoss()
# train the model
batchSize = 32
epochs = 100
batchCount = int(chunks/batchSize)
print ("batchCount = ",batchCount)
for epoch in range(1, epochs+1):
    for batch in range(0, batchCount):
        batchStart = batch * batchSize
        batchEnd = batchStart + batchSize
        if(batchEnd > chunks):
            batchEnd = chunks
        batchTensor = trainImageChunks[batchStart:batchEnd].transpose(1,3).transpose(2,3)
        optimizer.zero_grad()
        output = model(batchTensor)
        loss = criterion(output, batchTensor)
        loss.backward()
        optimizer.step()
        #print("Batch: " + str(batch) + " Loss: " + str(loss.item()))    
    print("Epoch: " + str(epoch) + " Loss: " + str(loss.item()))    
    inputExamples = batchTensor[0:10,:,:,:]
    outputExamples = output[0:10,:,:,:]
    inputGrid = torchvision.utils.make_grid(inputExamples)
    outputGrid = torchvision.utils.make_grid(outputExamples)
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_image('input', inputGrid, epoch)
    writer.add_image('output', outputGrid, epoch)
    writer.flush()

# save the model
torch.save(model.state_dict(), "./model2.pth")


batchCount =  256
Epoch: 1 Loss: 0.00040637419442646205
Epoch: 2 Loss: 0.0004094888863619417
Epoch: 3 Loss: 0.0004147384315729141
Epoch: 4 Loss: 0.0004154706548433751
Epoch: 5 Loss: 0.00041548986337147653
Epoch: 6 Loss: 0.0004138938966207206
Epoch: 7 Loss: 0.0004111786838620901
Epoch: 8 Loss: 0.0004088580026291311
Epoch: 9 Loss: 0.00040633039316162467
Epoch: 10 Loss: 0.0004040311323478818
Epoch: 11 Loss: 0.0004016394668724388
Epoch: 12 Loss: 0.00040006660856306553
Epoch: 13 Loss: 0.0003991770790889859
Epoch: 14 Loss: 0.00039898190880194306
Epoch: 15 Loss: 0.00039893988287076354
Epoch: 16 Loss: 0.0003987892996519804
Epoch: 17 Loss: 0.00039862707490101457
Epoch: 18 Loss: 0.0003984566719736904
Epoch: 19 Loss: 0.0003980707551818341
Epoch: 20 Loss: 0.0003974259307142347
Epoch: 21 Loss: 0.00039743143133819103
Epoch: 22 Loss: 0.0003979154280386865
Epoch: 23 Loss: 0.00039836502401158214
Epoch: 24 Loss: 0.0003987055388279259
Epoch: 25 Loss: 0.0003990711993537843
Epoch: 26 Loss: 0.00039967638440