In [13]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import init

from torchvision.models.resnet import BasicBlock, ResNet
from torchvision.transforms import ToTensor

In [14]:
import io
import torch.utils.data as data_utils
from PIL import Image
import os

import matplotlib.pyplot as plt
import torch.nn.functional as F
def default_loader(path):
    return Image.open(path)   

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models.resnet import BasicBlock, ResNet
from torch.nn import init
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Custom convolutional layer
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, bias=False, transposed=False):
    if transposed:
        layer = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=1, output_padding=1,
                                   dilation=dilation, bias=bias)
    else:
        padding = (kernel_size + 2 * (dilation - 1)) // 2
        layer = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias)
    if bias:
        init.constant_(layer.bias, 0)
    return layer

# Returns 2D batch normalization layer
def bn(planes):
    layer = nn.BatchNorm2d(planes)
    init.constant_(layer.weight, 1)
    init.constant_(layer.bias, 0)
    return layer

# Feature extraction using pretrained residual network
class FeatureResNet(ResNet):
    def __init__(self):
        super().__init__(BasicBlock, [3, 14, 16, 3], 1000)
        self.conv_f = conv(2, 64, kernel_size=3, stride=1)
        self.ReLu_1 = nn.ReLU(inplace=True)
        self.conv_pre = conv(512, 1024, stride=2, transposed=False)
        self.bn_pre = bn(1024)

    def forward(self, x):
        x1 = self.conv_f(x)  # Convolutional layer
        x = self.bn1(x1)
        x = self.relu(x)
        x2 = self.maxpool(x)  # Max pool layer
        x = self.layer1(x2)   # Residual layer 1
        x3 = self.layer2(x)   # Residual layer 2
        x4 = self.layer3(x3)  # Residual layer 3
        x5 = self.layer4(x4)  # Residual layer 4
        x6 = self.ReLu_1(self.bn_pre(self.conv_pre(x5)))  # Convolutional layer
        return x1, x2, x3, x4, x5, x6

# Segmentation network
class SegResNet(nn.Module):
    def __init__(self, num_classes, pretrained_net):
        super().__init__()
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)
        self.conv3_2 = conv(1024, 512, stride=1, transposed=False)
        self.bn3_2 = bn(512)
        self.conv4 = conv(512, 512, stride=2, transposed=True)
        self.bn4 = bn(512)
        self.conv5 = conv(512, 256, stride=2, transposed=True)
        self.bn5 = bn(256)
        self.conv6 = conv(256, 128, stride=2, transposed=True)
        self.bn6 = bn(128)
        self.conv7 = conv(128, 64, stride=2, transposed=True)
        self.bn7 = bn(64)
        self.conv8 = conv(64, 64, stride=2, transposed=True)
        self.bn8 = bn(64)
        self.conv9 = conv(64, 32, stride=2, transposed=True)
        self.bn9 = bn(32)
        self.convadd = conv(32, 16, stride=1, transposed=False)
        self.bnadd = bn(16)
        self.conv10 = conv(16, num_classes, stride=2, kernel_size=5)
        init.constant_(self.conv10.weight, 0)  # Zero init

    def forward(self, x):
        x1, x2, x3, x4, x5, x6 = self.pretrained_net(x)  # Feature extraction
        
        x = self.relu(self.bn3_2(self.conv3_2(x6)))
        x = self.relu(self.bn4(self.conv4(x)))
        x = self.relu(self.bn5(self.conv5(x)))
        x = self.relu(self.bn6(self.conv6(x + x4)))
        x = self.relu(self.bn7(self.conv7(x + x3)))
        x = self.relu(self.bn8(self.conv8(x + x2)))
        x = self.relu(self.bn9(self.conv9(x + x1)))
        x = self.relu(self.bnadd(self.convadd(x)))
        x = self.conv10(x)
        return x


In [16]:
saved_model_path = 'trained_model/model/param_all_2_99_742024_07_25_12-34-03'

fnet = FeatureResNet()
fcn = SegResNet(2, fnet)
fcn = fcn.cpu()

fcn.load_state_dict(torch.load(saved_model_path))

<All keys matched successfully>

In [17]:
folder = 'google-dataset/speckle-image-dataset/'
test_count = 512
train_count = 4096
            
test_set = []
for z in range(0, test_count):
    test_set.append((folder + 'imgs3/train_image_' + str(z+1)+'_1.png',
                       folder + 'imgs3/train_image_' + str(z+1)+'_2.png',
                       folder + 'gt3/train_image_' + str(z+1)+'.mat'))
            
train_set = []
for z in range(test_count, train_count):
    train_set.append((folder + 'imgs3/train_image_' + str(z+1)+'_1.png',
                       folder + 'imgs3/train_image_' + str(z+1)+'_2.png',
                       folder + 'gt3/train_image_' + str(z+1)+'.mat'))
    

In [18]:
print(train_set[0])

('google-dataset/speckle-image-dataset/imgs3/train_image_11_1.png', 'google-dataset/speckle-image-dataset/imgs3/train_image_11_2.png', 'google-dataset/speckle-image-dataset/gt3/train_image_11.mat')


In [19]:
import scipy.io as sio
import random

class MyDataset(data_utils.Dataset):
    def __init__(self, dataset, transform=None, target_transform=None, loader=default_loader):
        self.imgs = dataset
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def random_transform(self, img):
        # Initialize the transformation sequence
        transform_seq = transforms.Compose([
            # Apply random rotation with a probability
            # transforms.RandomApply([
            #     transforms.RandomRotation(degrees=(-10, 10))
            # ], p=0.5),

            # Apply random zoom with a probability of 0.5
            # transforms.RandomApply([
            #     transforms.RandomAffine(degrees=0, scale=(0.95, 1.05))
            # ], p=0.5),

            # Randomly adjust brightness, contrast, and saturation
            transforms.RandomApply([
                transforms.ColorJitter(brightness=(0.5, 1.5), contrast=(0.5, 1.5), saturation=(0.5, 1.5))
            ], p=0.5),
        ])
        
        # Apply the transformation sequence
        return transform_seq(img)

    def __getitem__(self, index):
        label_x, label_y, label_z = self.imgs[index]
        img1 = self.loader(label_x)
        img1 = self.random_transform(img1)
        img_1 = ToTensor()(img1.resize((128,128)))
        img_1 = img_1[::4,:,:]

        img2 = self.loader(label_y)
        img2 = self.random_transform(img2)
        img_2 = ToTensor()(img2.resize((128,128)))
        img_2 = img_2[::4,:,:]
        
        imgs = torch.cat((img_1, img_2), 0)
        
        try:
            gt = sio.loadmat(label_z)['Disp_field_1'].astype(float)
        except KeyError:
            gt = sio.loadmat(label_z)['Disp_field_2'].astype(float)
        
        gt = torch.tensor(gt).permute(2, 0, 1)  # Ensure gt is [C, H, W]
        gt = nn.functional.interpolate(gt.unsqueeze(0), size=(128, 128), mode='bilinear', align_corners=False).squeeze(0)  # Resize to [C, 128, 128]
        
        return imgs, gt

    def __len__(self):
        return len(self.imgs)

In [20]:
##TROUBLESHOOTING
import torch.optim as optim

In [21]:
# a simple custom collate function, just to show the idea
def my_collate(batch):
    data = [item[0] for item in batch]
    target = [item[1] for item in batch]
    
    # Check if all images have the same size
    sizes_data = [img.size() for img in data]
    if not all(size == sizes_data[0] for size in sizes_data):
        raise ValueError("All images in a batch must have the same size.")
    
    # Check if all ground truth tensors have the same size
    sizes_target = [gt.size() for gt in target]
    if not all(size == sizes_target[0] for size in sizes_target):
        raise ValueError("All ground truth tensors in a batch must have the same size.")
    
    target = torch.stack(target, 0)
    return [data, target]
    
EPOCH = 40
BATCH_SIZE = 64
print('BATCH_SIZE = ',BATCH_SIZE)
LR = 0.001              # learning rate
# too big: converges fast but misses local minimal. Small: time consuming. Use array: every 20 steps LR *= 0.9. 
NUM_WORKERS = 0

optimizer = torch.optim.Adam(fcn.parameters(), lr=LR)   # optimize all cnn parameters
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9) # dynamic learning rate
#optimizer = torch.optim.SGD(cnn.parameters(), lr=LR, momentum=0.9)   # optimize all cnn parameters
loss_func = nn.MSELoss()


train_data=MyDataset(dataset=train_set)
train_loader = data_utils.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

test_data=MyDataset(dataset=test_set)
test_loader = data_utils.DataLoader(dataset=test_data, batch_size=1)

BATCH_SIZE =  64


In [22]:
from datetime import datetime
dataString = datetime.strftime(datetime.now(), '%Y_%m_%d_%H-%M-%S')

In [23]:
root_result = 'trained_model/'
if not os.path.exists(root_result):
    os.mkdir(root_result)

model_result = os.path.join(root_result, 'model/')
log_result = os.path.join(root_result, 'log/')

if not os.path.exists(model_result):
    os.mkdir(model_result)

if not os.path.exists(log_result):
    os.mkdir(log_result)

In [24]:
from datetime import datetime
import time

fileOut=open(log_result+'log_'+dataString,'a')
fileOut.write(dataString+'Epoch:   Step:    Loss:        Val_Accu :\n')
fileOut.close()
fileOut2 = open(log_result+'validation_'+dataString, 'a')
fileOut2.write('kernal_size of conv_f is 2')
fileOut2.write(dataString+'Epoch:    loss:')
fileOut3 = open(log_result+'train_'+dataString, 'a')
fileOut3.write('kernal_size of conv_f is 2')
fileOut3.write(dataString+'Epoch:    loss:')

start_time = datetime.now()
print("Start Time:", start_time)

with open(log_result + 'elapsed_time_log_' + dataString, 'a') as fileOut:
    fileOut.write(f"Start time: {start_time}\n")

for epoch in range(EPOCH):
    fcn.train()
    total_loss = 0  # Initialize the loss accumulator
    num_batches = 0  # Counter for the number of batches

    for step, (img, gt) in enumerate(train_loader):   # gives batch data, normalize x when iterate train_loader
        img = Variable(img).cpu()
        gt = gt.float()
        gt = Variable(gt).cpu()
        output = fcn(img)  # cnn output
        loss = loss_func(output, gt)  # cross entropy loss
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        total_loss += loss.item()  # Accumulate the loss
        num_batches += 1

        print(f"epoch: {epoch}, batch step: {step}, loss: {loss.data.item()}")
        with open(log_result + 'log_' + dataString, 'a') as fileOut:
            fileOut.write(f"{epoch}   {step}   {loss.data.item()}\n")

        end_time = datetime.now()

        elapsed_time = end_time - start_time
        elapsed_hours, remainder = divmod(elapsed_time.total_seconds(), 3600)
        elapsed_minutes, elapsed_seconds = divmod(remainder, 60)

        # Output total elapsed time in hours:minutes:seconds
        print("{} ({}:{}:{})".format(
            end_time, int(elapsed_hours), int(elapsed_minutes), int(elapsed_seconds)
        ))
        
        with open(log_result + 'elapsed_time_log_' + dataString, 'a') as fileOut:
            fileOut.write(f"{epoch}   {step}   {end_time}   {elapsed_hours}:{elapsed_minutes}:{elapsed_seconds}\n")

    average_loss = total_loss / num_batches  # Calculate average loss for the epoch
    with open(log_result + 'train_' + dataString, 'a') as fileOut:
        fileOut.write(f"Epoch {epoch}: Average Loss: {average_loss}\n")  # Log the average loss per epoch
        print(f"Epoch {epoch}: Average Train Loss: {average_loss}\n")

    if epoch % 10 == 9:
        PATH = model_result + 'model' + dataString + '_' + str(epoch) + '_' + str(step)
        torch.save(fcn.state_dict(), PATH)
        print(f'Finished saving checkpoints for epoch {epoch}')

    print(f'Epoch {epoch}: Average Training Loss: {average_loss}')
    
    scheduler.step()
    print("Epoch {}, Current learning rate: {}".format(epoch, scheduler.get_last_lr()))
     
    LOSS_VALIDATION = 0
    fcn.eval()
    with torch.no_grad():
        for step, (img,gt) in enumerate(test_loader):

            img = Variable(img).cpu()
            # gt=gt.unsqueeze(1)# batch x
            gt=gt.float()
            gt = Variable(gt).cpu()
            # print(f"gt test size:{gt.size()}")
            output = fcn(img) 
            # print(f"validation output size:{output.size()}")
            LOSS_VALIDATION += loss_func(output, gt)
        #print(LOSS_VALIDATION.data.item())
        LOSS_VALIDATION = LOSS_VALIDATION/step
        fileOut2 = open(log_result+'validation_'+dataString, 'a')
        fileOut2.write(str(epoch)+'   '+str(step)+'   '+str(LOSS_VALIDATION.data.item())+'\n')
        fileOut2.close()
        print('validation error epoch  '+str(epoch)+':    '+str(LOSS_VALIDATION)+'\n'+str(step))

Start Time: 2024-07-30 15:49:12.106443
epoch: 0, batch step: 0, loss: 0.27605512738227844
2024-07-30 15:49:44.071385 (0:0:31)
epoch: 0, batch step: 1, loss: 0.7177653908729553
2024-07-30 15:49:56.004633 (0:0:43)
Epoch 0: Average Train Loss: 0.4969102591276169

Epoch 0: Average Training Loss: 0.4969102591276169
Epoch 0, Current learning rate: [0.001]
validation error epoch  0:    tensor(0.3548)
9
epoch: 1, batch step: 0, loss: 0.2839232385158539
2024-07-30 15:50:31.634508 (0:1:19)
epoch: 1, batch step: 1, loss: 0.34490251541137695
2024-07-30 15:50:43.401739 (0:1:31)
Epoch 1: Average Train Loss: 0.3144128769636154

Epoch 1: Average Training Loss: 0.3144128769636154
Epoch 1, Current learning rate: [0.001]
validation error epoch  1:    tensor(1.6309)
9
epoch: 2, batch step: 0, loss: 0.33376118540763855
2024-07-30 15:51:28.134261 (0:2:16)
epoch: 2, batch step: 1, loss: 0.2612926661968231
2024-07-30 15:51:45.885607 (0:2:33)
Epoch 2: Average Train Loss: 0.29752692580223083

Epoch 2: Average T

KeyboardInterrupt: 