In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2
import subprocess
import os
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as f
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchmetrics.image import StructuralSimilarityIndexMeasure
import torchvision.models as models
from unet_model import UNet

ImportError: DLL load failed while importing cv2: The specified module could not be found.

In [None]:
# Report the PyTorch version, GPU availability and the installed CUDA toolkit (nvcc).
print(torch.__version__)
if not torch.cuda.is_available():
    print("No GPU available")
else:
    print("GPU available")
    print("Num GPUs available: ", torch.cuda.device_count())

# Run nvcc directly (no shell). A missing or failing toolkit is reported explicitly instead of
# printing an empty line, which is what the shell=True call with stderr discarded used to do.
command = ["nvcc", "--version"]
try:
    result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    print(result.stdout if result.returncode == 0 else result.stderr)
except OSError as exc:
    print(f"could not run nvcc (CUDA toolkit not installed or not on PATH): {exc}")

# Every tensor/parameter created after this line without an explicit dtype is float64
# (e.g. the UNet and VGG parameters built in later cells).
torch.set_default_dtype(torch.float64)


In [None]:
class CustomDataset(Dataset):
    """SIM dataset: raw frames, ground truth, low-res image, illumination patterns and PSF.

    Files read by ``__getitem__`` (with ``root = os.path.join(dir_name, data_name)``):

    * ``root/input_frames/{idx+1}_{j}.png`` for ``j = 1..frames``
    * ``root/ground_truth/{idx+1}.png`` and ``root/low_res/{idx+1}.png``
    * ``root/patterns/{j}.png`` for ``j = 1..frames`` (the same files for every sample)
    * ``root/psf.png``

    All images except the PSF are read as grayscale, resized to ``im_dim`` x ``im_dim``
    (bicubic) and center-cropped to ``cropsize`` x ``cropsize``.

    Args:
        data_name: dataset sub-directory name.
        dir_name: directory containing ``data_name``.
        num_in: number of samples reported by ``__len__``.
        im_dim: side length images are resized to before cropping.
        cropsize: side length of the central crop; must not exceed ``im_dim``.
        frames: number of raw frames / illumination patterns per sample.
        bg_lvl: background level. Stored only; it is not applied here (the caller subtracts it).
    """

    def __init__(self, data_name, dir_name, num_in, im_dim, cropsize, frames, bg_lvl):
        if cropsize > im_dim:
            raise ValueError(f"cropsize ({cropsize}) must not exceed im_dim ({im_dim})")
        self.data_name = data_name
        self.dir_name = dir_name
        self.num_in = num_in
        self.im_dim = im_dim
        self.cropsize = cropsize
        self.frames = frames
        self.bg_lvl = bg_lvl
        self.transform = transforms.Compose([transforms.ToTensor()])

    def __len__(self):
        return self.num_in

    def _read_gray(self, path):
        """Read ``path`` as a single-channel uint8 image, raising if it cannot be read."""
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if img is None:  # cv2.imread reports a missing/unreadable file by returning None
            raise FileNotFoundError(f"Could not read image: {path}")
        return img

    def _center_crop(self, img):
        """Return the central ``cropsize`` x ``cropsize`` window of an ``im_dim``-sized image."""
        offset = (self.im_dim - self.cropsize) // 2
        return img[offset:offset + self.cropsize, offset:offset + self.cropsize]

    def _load_frame(self, path):
        """Read one image, resize it bicubically to ``im_dim`` x ``im_dim`` and center-crop it."""
        img = self._read_gray(path)
        img = cv2.resize(img, dsize=(self.im_dim, self.im_dim), interpolation=cv2.INTER_CUBIC)
        return self._center_crop(img)

    def _load_stack(self, paths):
        """Load ``paths`` into a float64 array of shape (cropsize, cropsize, len(paths))."""
        stack = np.zeros([self.cropsize, self.cropsize, len(paths)])
        for k, path in enumerate(paths):
            stack[:, :, k] = self._load_frame(path)
        return stack

    def __getitem__(self, idx):
        """Load sample ``idx`` (0-based; the files are numbered from 1).

        Returns a dict of CHW tensors produced by ``ToTensor``: ``input_frames`` and
        ``patterns`` of shape (frames, cropsize, cropsize), ``gt_frames`` and ``lr_frames`` of
        shape (1, cropsize, cropsize), and ``psf`` of shape (1, kh, kw). The stacks are float64
        arrays, which ToTensor does not rescale (values stay in 0..255); the PSF is uint8, which
        ToTensor divides by 255.
        """
        root = os.path.join(self.dir_name, self.data_name)
        frame_ids = range(1, self.frames + 1)  # frame/pattern files are numbered from 1

        input_frames = self._load_stack(
            [os.path.join(root, 'input_frames', f'{idx + 1}_{j}.png') for j in frame_ids])
        gt_frames = self._load_stack([os.path.join(root, 'ground_truth', f'{idx + 1}.png')])
        lr_frames = self._load_stack([os.path.join(root, 'low_res', f'{idx + 1}.png')])
        patterns = self._load_stack(
            [os.path.join(root, 'patterns', f'{j}.png') for j in frame_ids])
        # Keep the PSF at its native size and uint8 dtype; add a channel axis (works for
        # non-square kernels too).
        psf = self._read_gray(os.path.join(root, 'psf.png'))[:, :, np.newaxis]

        return {
            'input_frames': self.transform(input_frames),
            'gt_frames': self.transform(gt_frames),
            'lr_frames': self.transform(lr_frames),
            'patterns': self.transform(patterns),
            'psf': self.transform(psf),
        }
    
# Usage: build the dataset, pull one sample and normalise it for single-image optimisation.
data_name = 'microtubules'
dir_name = 'Data/SIM/'
num_in = 50
im_dim = 480
cropsize = 480
frames = 9
bg_lvl = 0

custom_dataset = CustomDataset(data_name, dir_name, num_in, im_dim, cropsize, frames, bg_lvl)
# Not used by the single-sample optimisation below; kept for batch experiments.
data_loader = DataLoader(dataset=custom_dataset, batch_size=1, shuffle=True)

ind = 23  # 1-based sample number (files are named 1..num_in); the dataset index is ind - 1
sample = custom_dataset[ind - 1]
input_frames = sample['input_frames']
gt_frames = sample['gt_frames']
lr_frames = sample['lr_frames']
patterns = sample['patterns']
psf = sample['psf']


def _scale_to_unit_max(t, name):
    """Return ``t / t.max()`` so that the brightest value becomes 1.

    Raises ValueError when the maximum is not positive; dividing by it would silently fill
    the tensor with inf/NaN and poison the optimisation.
    """
    peak = t.max().item()
    if peak <= 0:
        raise ValueError(f"cannot normalise {name}: maximum value is {peak}")
    return t / peak


# Subtract the background from the raw frames (clipping negatives to zero), then scale each
# stack so its maximum is 1. The PSF is left as loaded.
input_frames = _scale_to_unit_max((input_frames - bg_lvl).clamp(min=0), 'input_frames')
gt_frames = _scale_to_unit_max(gt_frames, 'gt_frames')
lr_frames = _scale_to_unit_max(lr_frames, 'lr_frames')
patterns = _scale_to_unit_max(patterns, 'patterns')

for name, t in (('input_frames', input_frames), ('gt_frames', gt_frames),
                ('lr_frames', lr_frames), ('patterns', patterns)):
    print(name, 'max:', t.max().item(), 'shape:', tuple(t.shape))
print('psf shape:', tuple(psf.shape))

In [None]:
# Show the first raw SIM frame. Flattening the leading axes makes this work both for the 3-D
# CPU tensor from the loading cell and for the 4-D (batched, possibly CUDA) tensor made later.
input_0 = input_frames.detach().cpu().reshape(-1, *input_frames.shape[-2:])[0].numpy()
fig, ax = plt.subplots()
ax.imshow(input_0)
ax.set_title('Input frame 1')
plt.show()

In [None]:
# class UNet(nn.Module):
#     def __init__(self, k_size=3):
#         super(UNet, self).__init__()
# 
#         # Downward path
#         self.conv1 = nn.Conv2d(9, 32, kernel_size=k_size, padding=1)  # Use padding=1 for 'same' padding
#         self.relu1 = nn.ReLU()
#         self.conv2 = nn.Conv2d(32, 32, kernel_size=k_size, padding=1)
#         self.relu2 = nn.ReLU()
#         self.pool1 = nn.MaxPool2d(2)
# 
#         self.conv4 = nn.Conv2d(32, 64, kernel_size=k_size, padding=1)
#         self.relu4 = nn.ReLU()
#         self.conv5 = nn.Conv2d(64, 64, kernel_size=k_size, padding=1)
#         self.relu5 = nn.ReLU()
#         self.pool2 = nn.MaxPool2d(2)
# 
#         self.conv6 = nn.Conv2d(64, 128, kernel_size=k_size, padding=1)
#         self.relu6 = nn.ReLU()
#         self.conv7 = nn.Conv2d(128, 128, kernel_size=k_size, padding=1)
#         self.relu7 = nn.ReLU()
# 
#         # Upward path
#         self.conv8 = nn.Conv2d(128, 256, kernel_size=k_size, padding=1)
#         self.relu8 = nn.ReLU()
#         self.conv9 = nn.Conv2d(256, 256, kernel_size=k_size, padding=1)
#         self.relu9 = nn.ReLU()
#         self.up1 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
# 
#         self.conv10 = nn.Conv2d(128 + 128, 128, kernel_size=k_size, padding=1)
#         self.relu10 = nn.ReLU()
#         self.conv11 = nn.Conv2d(128, 128, kernel_size=k_size, padding=1)
#         self.relu11 = nn.ReLU()
#         self.up2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
# 
#         self.conv12 = nn.Conv2d(64 + 64, 64, kernel_size=k_size, padding=1)
#         self.relu12 = nn.ReLU()
#         self.conv13 = nn.Conv2d(64, 64, kernel_size=k_size, padding=1)
#         self.relu13 = nn.ReLU()
#         self.up3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
# 
#         self.conv14 = nn.Conv2d(32 + 32, 32, kernel_size=k_size, padding=1)
#         self.relu14 = nn.ReLU()
#         self.conv15 = nn.Conv2d(32, 32, kernel_size=k_size, padding=1)
#         self.relu15 = nn.ReLU()
#         self.conv16 = nn.Conv2d(32, 32, kernel_size=k_size, padding=1)
#         self.relu16 = nn.ReLU()
# 
#         self.output = nn.Conv2d(32, 1, kernel_size=1, padding=1)
# 
#     def forward(self, x):
#         # Downward path
#         x1 = self.conv1(x)
#         x1 = self.relu1(x1)
#         x1 = self.conv2(x1)
#         x1 = self.relu2(x1)
#         x1 = self.conv3(x1)
#         x1 = self.relu3(x1)
#         down1 = self.pool1(x1)
# 
#         x2 = self.conv4(down1)
#         x2 = self.relu4(x2)
#         x2 = self.conv5(x2)
#         x2 = self.relu5(x2)
#         down2 = self.pool2(x2)
# 
#         x3 = self.conv6(down2)
#         x3 = self.relu6(x3)
#         x3 = self.conv7(x3)
#         x3 = self.relu7(x3)
#         down3 = self.pool3(x3)
# 
#         # Upward path
#         x4 = self.conv8(down3)
#         x4 = self.relu8(x4)
#         x4 = self.conv9(x4)
#         x4 = self.relu9(x4)
#         up1 = self.up1(x4)
# 
#         cat1 = torch.cat([x3, up1], dim=1)
#         x5 = self.conv10(cat1)
#         x5 = self.relu10(x5)
#         x5 = self.conv11(x5)
#         x5 = self.relu11(x5)
#         up2 = self.up2(x5)
# 
#         cat2 = torch.cat([x2, up2], dim=1)
#         x6 = self.conv12(cat2)
#         x6 = self.relu12(x6)
#         x6 = self.conv13(x6)
#         x6 = self.relu13(x6)
#         up3 = self.up3(x6)
# 
#         cat3 = torch.cat([x1, up3], dim=1)
#         x7 = self.conv14(cat3)
#         x7 = self.relu14(x7)
#         x7 = self.conv15(x7)
#         x7 = self.relu15(x7)
#         x7 = self.conv16(x7)
#         x7 = self.relu16(x7)
# 
#         output = self.output(x7)
# 
#         return output

In [None]:
# SSIM-based dissimilarity
def SSIM_Loss(y_true, y_pred):
    """Return ``1 - ssim(y_true, y_pred)``; 0 when the metric reports perfect similarity.

    ``ssim`` is the module-level metric created in the model-setup cell, so this function
    can only be called after that cell has run.
    """
    similarity = ssim(y_true, y_pred)
    return 1 - similarity

# Perceptual (VGG-feature) loss
def perceptual_loss(in1, in2):
    """Mean-squared error between the VGG-19 feature maps of two images.

    Each input is passed through the module-level ``vgg_normalize`` and then the module-level
    ``vgg19`` feature extractor (both created in the model-setup cell). Presumably expects
    3-channel (B, 3, H, W) inputs, as the commented-out caller builds — TODO confirm.
    """
    feats_a, feats_b = (vgg19(vgg_normalize(img)) for img in (in1, in2))
    return f.mse_loss(feats_a, feats_b)


In [None]:
def Physics_loss(y_true, y_pred, pattern_stack=None, psf_kernel=None):
    """Physics-informed loss: SSIM between measured frames and frames simulated from ``y_pred``.

    Forward model for each frame j: ``sim_j = conv2d(y_pred * pattern_j, psf)`` ("same" padding
    for odd kernels). The simulated stack is min-max normalised over the whole tensor and
    compared with ``y_true`` via ``SSIM_Loss`` (which reads the module-level ``ssim`` metric).

    The loss is computed once for the whole batch; SSIM already averages over it. (The earlier
    per-sample loop recomputed the full-batch loss ``batch_size`` times and summed the results.)

    Args:
        y_true: measured frames; the training loop passes ``input_frames`` (B, F, H, W).
        y_pred: network output, presumably (B, 1, H, W) from ``UNet(n_classes=1)`` — TODO
            confirm; it is broadcast against every pattern.
        pattern_stack: illumination patterns broadcastable with ``y_pred`` to (B, F, H, W).
            Defaults to the notebook-level ``patterns`` tensor.
        psf_kernel: point-spread function with ``kh * kw`` elements, e.g. (1, 1, kh, kw).
            Defaults to the notebook-level ``psf`` tensor.

    Returns:
        Scalar loss tensor (``1 - SSIM``).
    """
    # Only touch the notebook globals when no explicit tensor was supplied.
    illum = patterns if pattern_stack is None else pattern_stack
    kernel = psf if psf_kernel is None else psf_kernel

    prod = y_pred * illum
    n_frames = prod.shape[1]
    kh, kw = kernel.shape[-2:]
    # One copy of the PSF per frame + groups=n_frames convolves every frame independently,
    # exactly like the former per-frame loop but in a single call.
    weight = kernel.to(prod).reshape(1, 1, kh, kw).repeat(n_frames, 1, 1, 1)
    conv = f.conv2d(prod, weight, stride=1,
                    padding=((kh - 1) // 2, (kw - 1) // 2), groups=n_frames)

    # Global min-max normalisation; the clamp avoids 0/0 = NaN for a constant simulation.
    lo, hi = conv.min(), conv.max()
    conv = (conv - lo) / (hi - lo).clamp_min(1e-12)

    return SSIM_Loss(conv, y_true)

In [None]:
# Build the network: one input channel per raw SIM frame, one output channel (the reconstruction).
model = UNet(n_channels=frames, n_classes=1)
print("{} parameters in total".format(sum(x.numel() for x in model.parameters())))
learning_rate = 0.001

# Use the first GPU when present, otherwise fall back to the CPU instead of crashing on .to().
# The name `cuda` is kept because later cells move tensors with `.to(cuda)`.
cuda = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = model.to(cuda)

# Adam optimizer with the initial learning rate
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# StepLR (not exponential decay): multiply the LR by 0.9 every 50 scheduler steps. The training
# cell steps it once per epoch, so with epochs = 50 the first decay lands after the final update.
lr_schedule = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.90)

# Set batch size and number of epochs
batch_size = 1
epochs = 50

############################################################################################################
# Metric used by SSIM_Loss (inputs are min-max / max normalised to [0, 1]).
ssim = StructuralSimilarityIndexMeasure(data_range=1.0, reduction='elementwise_mean').to(cuda)
# VGG-19 features and ImageNet normalisation are only used by perceptual_loss, which
# Physics_loss does not currently call.
vgg19 = models.vgg19(weights='VGG19_Weights.DEFAULT').features.to(cuda).eval()
vgg_normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]).to(cuda)
############################################################################################################

l = Physics_loss  # loss used by the training loop

print(model)

In [None]:
# Prepare 4-D (batch, channel, H, W) tensors on the training device.
def _as_batch(t):
    """Prepend a batch axis to a 3-D (C, H, W) tensor; return already-batched tensors unchanged.

    The dim() guard makes this cell idempotent: re-running it no longer stacks extra leading
    axes (a plain unsqueeze(0) produced 5-D tensors on the second run).
    """
    return t.unsqueeze(0) if t.dim() == 3 else t


input_frames = _as_batch(input_frames).to(cuda)
gt_frames = _as_batch(gt_frames).to(cuda)
lr_frames = _as_batch(lr_frames).to(cuda)
patterns = _as_batch(patterns).to(cuda)
psf = _as_batch(psf).to(cuda)
print("input shape:", input_frames.shape)
print("gt shape:", gt_frames.shape)
print("lr shape:", lr_frames.shape)
print("patterns shape:", patterns.shape)
print("psf shape:", psf.shape)

In [None]:
# Optimise the network on the single sample: the physics loss compares frames simulated from
# the network output with the measured input frames.
train_losses = []
torch.backends.cudnn.benchmark = True
model.train()  # the evaluation cell leaves the model in eval mode; restore training mode on re-runs
for epoch in range(epochs):
    optimizer.zero_grad(set_to_none=True)  # same as setting every param.grad to None
    outputs = model(input_frames)  # forward pass
    loss = l(input_frames, outputs)
    loss.backward()  # backpropagate the loss
    optimizer.step()  # update the weights
    lr_schedule.step()  # update the learning rate
    # Record a Python float: keeping the tensor would retain every epoch's autograd graph in
    # (GPU) memory and makes plt.plot fail on CUDA tensors that require grad.
    train_loss = loss.item()
    train_losses.append(train_loss)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))
    
        

In [None]:
# Per-epoch optimisation loss. float() accepts Python floats and 0-d tensors alike, and the
# x-axis follows the number of recorded epochs (robust to an interrupted training run).
loss_values = [float(v) for v in train_losses]
fig, ax = plt.subplots()
ax.plot(range(1, len(loss_values) + 1), loss_values, label='Optimization loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Reconstruct with the trained network and compare it with the ground truth and the
# diffraction-limited image.
BORDER = 20  # pixels trimmed from every side before display (presumably to hide edge artefacts)


def _crop_and_normalize(img, border=BORDER):
    """Trim ``border`` pixels from each side of a 2-D array and min-max scale it to [0, 1].

    Cropping follows the array's own shape, so it stays correct when ``cropsize != im_dim``.
    A constant image maps to zeros instead of NaN.
    """
    rows, cols = img.shape
    img = img[border:rows - border, border:cols - border]
    lo, hi = np.amin(img), np.amax(img)
    if hi > lo:
        return (img - lo) / (hi - lo)
    return np.zeros_like(img)


model.eval()
with torch.no_grad():
    output_tensor = model(input_frames)

# First image / first channel of each batch. gt_frames and lr_frames stay tensors (new names
# hold the numpy views) so this cell and the training cell can be re-run.
recon_image = _crop_and_normalize(output_tensor[0, 0].detach().cpu().numpy())
gt = _crop_and_normalize(gt_frames[0, 0].detach().cpu().numpy())
lowres = _crop_and_normalize(lr_frames[0, 0].detach().cpu().numpy())

panels = (
    ('Ground truth image', gt),
    ('PINN Result', recon_image),
    ('Diffraction limited image', lowres),
)
fig, axes = plt.subplots(1, 3, figsize=(20, 10))
for ax, (title, img) in zip(axes, panels):
    ax.imshow(img, cmap='inferno')
    ax.set_title(title)
plt.show()