Init

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import numpy as np
from torchvision import datasets, transforms
import time
from PIL import Image
from torch.utils.tensorboard import SummaryWriter
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
print(device)

cuda


Model Definition

In [2]:
class CNN(nn.Module):
    def __init__(self, inChannels=3, chunkSizeX = 64,chunkSizeY = 64):
        super(CNN, self).__init__()
        self.chunkSizeX = chunkSizeX
        self.chunkSizeY = chunkSizeY
        self.intermediateSize = 1024
        self.transfer = nn.Sigmoid()
        self.encoder = nn.Sequential(
            nn.Linear(in_features=3*chunkSizeX*chunkSizeY, out_features=self.intermediateSize),
            nn.ReLU(),
        )                
        self.flatten = nn.Flatten()  
        self.unflatten  = nn.Unflatten(1, (3,chunkSizeY,chunkSizeX))

        self.decoder = nn.Sequential(
            nn.Linear(in_features=self.intermediateSize, out_features=3*chunkSizeX*chunkSizeY),
            self.transfer,
        )
    
    def forward(self, x):
        out = self.flatten(x)
        out = self.encoder(out)
        out = self.decoder(out)
        out = self.unflatten(out)
        return out

model = CNN().to(device)

In [3]:
# Create Data Set
trainImage = Image.open("./test-cropped.png")
chunkSizeX = model.chunkSizeX
chunkSizeY = model.chunkSizeY
trainImageTensor = torch.tensor(np.array(trainImage)).to(device)/255.
print("Image Tensor Shape: " + str(trainImageTensor.shape))
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("Image size: " + str(width) + "x" + str(height))
xRemain = width % chunkSizeX
yRemain = height % chunkSizeY
print("Remains: " + str(xRemain) + " - " + str(yRemain))
trainImageTensor = trainImageTensor[0:height-yRemain, 0:width-xRemain,:]
width = trainImageTensor.shape[1]
height = trainImageTensor.shape[0]
print("trainImageTensor shape after crop= ",trainImageTensor.shape)
print("New Image size: " + str(width) + "x" + str(height))

xChunks = int((width)/chunkSizeX)
yChunks = int((height)/chunkSizeY)
chunks = xChunks * yChunks
print("xChunks: " + str(xChunks))
print("yChunks: " + str(yChunks))
print("Chunks: " + str(chunks))
stride4 = 1
stride3 = 3
stride2 = width*3
stride1 = chunkSizeX
stride0 = chunkSizeY*width*3
trainImageChunks = trainImageTensor.as_strided((yChunks, xChunks,chunkSizeY, chunkSizeX,3), (stride0,stride1, stride2, stride3, stride4)).flatten(0,1)
print(trainImageChunks.shape)
print("chunk shape = ",trainImageChunks[0].shape)

# for i in range(0, 32):
#     chunk = trainImageChunks[8192-32+i].cpu()
#     img = Image.fromarray(np.uint8(chunk*255)).save("./images/test-out_"+str(i)+".png")


Image Tensor Shape: torch.Size([2048, 4096, 3])
Image size: 4096x2048
Remains: 0 - 0
trainImageTensor shape after crop=  torch.Size([2048, 4096, 3])
New Image size: 4096x2048
xChunks: 32
yChunks: 16
Chunks: 512
torch.Size([512, 128, 128, 3])
chunk shape =  torch.Size([128, 128, 3])


In [6]:
#model.load_state_dict(torch.load("./model1.pth"))
#model.eval()

#optimizer = optim.Adam(model.parameters(), lr=0.005)
optimizer = optim.SGD(model.parameters(), lr=1, momentum=0.1)
writer = SummaryWriter()
# create the loss function
criterion = nn.MSELoss()
#criterion = nn.CrossEntropyLoss()
# train the model
batchSize = 32
epochs = 500
batchCount = int(chunks/batchSize)
print ("batchCount = ",batchCount)
for epoch in range(1, epochs+1):
    for batch in range(0, batchCount):
        batchStart = batch * batchSize
        batchEnd = batchStart + batchSize
        if(batchEnd > chunks):
            batchEnd = chunks
        batchTensor = trainImageChunks[batchStart:batchEnd].transpose(1,3).transpose(2,3)
        optimizer.zero_grad()
        output = model(batchTensor)
        loss = criterion(output, batchTensor)
        loss.backward()
        optimizer.step()
        #print("Batch: " + str(batch) + " Loss: " + str(loss.item()))    
    print("Epoch: " + str(epoch) + " Loss: " + str(loss.item()))    
    inputExamples = batchTensor[0:10,:,:,:]
    outputExamples = output[0:10,:,:,:]
    inputGrid = torchvision.utils.make_grid(inputExamples)
    outputGrid = torchvision.utils.make_grid(outputExamples)
    writer.add_scalar('Loss/train', loss.item(), epoch)
    writer.add_image('input', inputGrid, epoch)
    writer.add_image('output', outputGrid, epoch)
    writer.flush()

# save the model
torch.save(model.state_dict(), "./model2.pth")


batchCount =  16
Epoch: 1 Loss: 0.17043165862560272
Epoch: 2 Loss: 0.1704246699810028
Epoch: 3 Loss: 0.1704176664352417
Epoch: 4 Loss: 0.17041067779064178
Epoch: 5 Loss: 0.1704036444425583
Epoch: 6 Loss: 0.17039665579795837
Epoch: 7 Loss: 0.17038965225219727
Epoch: 8 Loss: 0.17038266360759735
Epoch: 9 Loss: 0.17037567496299744
Epoch: 10 Loss: 0.17036867141723633
Epoch: 11 Loss: 0.17036166787147522
Epoch: 12 Loss: 0.1703546941280365
Epoch: 13 Loss: 0.1703476756811142
Epoch: 14 Loss: 0.17034070193767548
Epoch: 15 Loss: 0.17033368349075317
Epoch: 16 Loss: 0.17032670974731445
Epoch: 17 Loss: 0.17031972110271454
Epoch: 18 Loss: 0.17031273245811462
Epoch: 19 Loss: 0.17030572891235352
Epoch: 20 Loss: 0.1702987402677536
Epoch: 21 Loss: 0.1702917516231537
Epoch: 22 Loss: 0.17028474807739258
Epoch: 23 Loss: 0.17027777433395386
Epoch: 24 Loss: 0.17027080059051514
Epoch: 25 Loss: 0.17026381194591522
Epoch: 26 Loss: 0.1702568233013153
Epoch: 27 Loss: 0.1702498495578766
Epoch: 28 Loss: 0.17024287581

Exception in thread Thread-5:
Traceback (most recent call last):
  File "F:\Python\Python310\lib\threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "F:\Python\Python310\lib\site-packages\tensorboard\summary\writer\event_file_writer.py", line 233, in run
    self._record_writer.write(data)
  File "F:\Python\Python310\lib\site-packages\tensorboard\summary\writer\record_writer.py", line 40, in write
    self._writer.write(header + header_crc + data + footer_crc)
  File "F:\Python\Python310\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 766, in write
    self.fs.append(self.filename, file_content, self.binary_mode)
  File "F:\Python\Python310\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 160, in append
    self._write(filename, file_content, "ab" if binary_mode else "a")
  File "F:\Python\Python310\lib\site-packages\tensorboard\compat\tensorflow_stub\io\gfile.py", line 164, in _write
    with io.open(filename, mode, encodi

Epoch: 154 Loss: 0.16936883330345154
