Init

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from torchvision import datasets, transforms
import time
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

cuda


Model Definition

In [3]:
class CNN(nn.Module):
    def __init__(self, inChannels=3, chunkSizeX = 64,chunkSizeY = 64):
        super(CNN, self).__init__()
        self.chunkSizeX = chunkSizeX
        self.chunkSizeY = chunkSizeY
        self.intermediateSize = 1024
        self.transfer = nn.Sigmoid()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=3*chunkSizeX*chunkSizeY, out_features=self.intermediateSize),
            nn.ReLU(),
        )                
        self.flatten = nn.Flatten()  
        self.unflatten  = nn.Unflatten(1, (3,chunkSizeY,chunkSizeX))

        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.intermediateSize, out_features=3*chunkSizeX*chunkSizeY),
            self.transfer,
        )
    
    def forward(self, x):
        out = self.flatten(x)
        out = self.encoder(out)
        out = self.decoder(out)
        out = self.unflatten(out)
        return out

model = CNN().to(device)

In [4]:
# Create Data Set
trainImage = Image.open("./test-cropped.png")
chunkSizeX = model.chunkSizeX
chunkSizeY = model.chunkSizeY
trainImageTensor = torch.tensor(np.array(trainImage)).to(device)/255.
print("Image Tensor Shape: " + str(trainImageTensor.shape))
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("Image size: " + str(width) + "x" + str(height))
xRemain = width % chunkSizeX
yRemain = height % chunkSizeY
print("Remains: " + str(xRemain) + " - " + str(yRemain))
trainImageTensor = trainImageTensor[0:height-yRemain, 0:width-xRemain,:]
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("trainImageTensor shape after crop= ",trainImageTensor.shape)
print("New Image size: " + str(width) + "x" + str(height))

xChunks = int((width)/chunkSizeX)
yChunks = int((height)/chunkSizeY)
chunks = xChunks * yChunks
print("xChunks: " + str(xChunks))
print("yChunks: " + str(yChunks))
print("Chunks: " + str(chunks))
stride4 = 1
stride3 = 3
stride2 = width*3
stride1 = chunkSizeX
stride0 = chunkSizeY*width*3
trainImageChunks = trainImageTensor.as_strided((yChunks, xChunks,chunkSizeY, chunkSizeX,3), (stride0,stride1, stride2, stride3, stride4))
print(trainImageChunks.shape)
trainImageChunks = trainImageChunks.flatten(0,1)
print(trainImageChunks.shape)
# stride4 = 1
# stride3 = 3
# stride2 = width*3

# stride1 = 3
# stride0 = width*3
# trainImageChunks2 = trainImageTensor.as_strided((chunkSizeY, chunkSizeX,chunkSizeY, chunkSizeX,3), (stride0,stride1, stride2, stride3, stride4))
# print(trainImageChunks2.shape)
# trainImageChunks2 = trainImageChunks2.flatten(0,1)
# print(trainImageChunks2.shape)
# print("chunk shape = ",trainImageChunks2[0].shape)

# for i in range(0, 32):
#     chunk = trainImageChunks2[i].cpu()
#     img = Image.fromarray(np.uint8(chunk*255)).save("./images/test-out_"+str(i)+".png")


Image Tensor Shape: torch.Size([2048, 4096, 3])
Image size: 4096x2048
Remains: 0 - 0
trainImageTensor shape after crop=  torch.Size([2048, 4096, 3])
New Image size: 4096x2048
xChunks: 64
yChunks: 32
Chunks: 2048
torch.Size([32, 64, 64, 64, 3])
torch.Size([2048, 64, 64, 3])


In [17]:
mnist = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
print (mnist[0][0].shape)

torch.Size([1, 28, 28])


In [2]:
import image_dataset1
dataset = image_dataset1.ImageDataset1()

print ("Dataset length: " + str(len(dataset)))

dataset[0][0].shape

Dataset length: 2048


torch.Size([3, 64, 64])

In [5]:
#model.load_state_dict(torch.load("./model2.pth"))
#model.eval()

optimizer = optim.Adam(model.parameters(), lr=0.001)
#optimizer = optim.SGD(model.parameters(), lr=10, momentum=0.1)
writer = SummaryWriter()
# create the loss function
criterion = nn.MSELoss()
#criterion = nn.CrossEntropyLoss()
# train the model
batchSize = 32
epochs = 15000
batchCount = int(chunks/batchSize)
print ("batchCount = ",batchCount)
for epoch in range(1, epochs+1):
    for batch in range(0, batchCount):
        batchStart = batch * batchSize
        batchEnd = batchStart + batchSize
        if(batchEnd > chunks):
            batchEnd = chunks
        batchTensor = trainImageChunks[batchStart:batchEnd].transpose(1,3).transpose(2,3)
        optimizer.zero_grad()
        output = model(batchTensor)
        loss = criterion(output, batchTensor)
        loss.backward()
        optimizer.step()
        #print("Batch: " + str(batch) + " Loss: " + str(loss.item()))    
    print("Epoch: " + str(epoch) + " Loss: " + str(loss.item()))    
    inputExamples = batchTensor[0:10,:,:,:]
    outputExamples = output[0:10,:,:,:]
    inputGrid = torchvision.utils.make_grid(inputExamples)
    outputGrid = torchvision.utils.make_grid(outputExamples)
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_image('input', inputGrid, epoch)
    writer.add_image('output', outputGrid, epoch)
    writer.flush()

# save the model
torch.save(model.state_dict(), "./model2-1.pth")


batchCount =  64
Epoch: 1 Loss: 0.11664383113384247
Epoch: 2 Loss: 0.0977165549993515
Epoch: 3 Loss: 0.09124251455068588
Epoch: 4 Loss: 0.07582320272922516
Epoch: 5 Loss: 0.06403671950101852
Epoch: 6 Loss: 0.06022229790687561
Epoch: 7 Loss: 0.05717259645462036
Epoch: 8 Loss: 0.05447704717516899
Epoch: 9 Loss: 0.050861090421676636
Epoch: 10 Loss: 0.04822021722793579
Epoch: 11 Loss: 0.04596956819295883
Epoch: 12 Loss: 0.04458176717162132
Epoch: 13 Loss: 0.04307396337389946
Epoch: 14 Loss: 0.041118696331977844
Epoch: 15 Loss: 0.03961924836039543
Epoch: 16 Loss: 0.03816228359937668
Epoch: 17 Loss: 0.037387385964393616
Epoch: 18 Loss: 0.03641321510076523
Epoch: 19 Loss: 0.03591479733586311
Epoch: 20 Loss: 0.035255447030067444
Epoch: 21 Loss: 0.034693680703639984
Epoch: 22 Loss: 0.03441593796014786
Epoch: 23 Loss: 0.03466454893350601
Epoch: 24 Loss: 0.034462008625268936
Epoch: 25 Loss: 0.03327658399939537
Epoch: 26 Loss: 0.03318615257740021
Epoch: 27 Loss: 0.03357147425413132
Epoch: 28 Loss: