In [None]:
import warnings
warnings.filterwarnings(action='ignore')

import os
import pandas as pd
import numpy as np

from PIL import Image
import matplotlib.pyplot as plt

import torch
import torchvision
from torch import Tensor, nn, optim
from torch.utils.data import Dataset, DataLoader, random_split

import torch.nn.functional as F
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models

import tqdm
from torch.optim.adam import Adam

In [None]:
import random


def seed_everything(seed: int = 42) -> None:
    """Seed every RNG this notebook relies on so runs are reproducible.

    Seeds Python's ``random``, NumPy's legacy global RNG and PyTorch's CPU and
    CUDA generators, and configures cuDNN to use deterministic kernels.

    Args:
        seed: Seed value applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects child processes started afterwards; the running
    # interpreter's hash seed was fixed when it started.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not just the current device.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose algorithms by timing them, which can differ
    # between runs and defeats `deterministic`; keep it off for reproducibility.
    torch.backends.cudnn.benchmark = False


seed_everything()

In [None]:
# Load the annotation table (one row per image pair) for a quick look below.
df = pd.read_csv(filepath_or_buffer="/Users/AI/HMH/UTRC/Dataset/1/Dataset.csv")

In [None]:
# Preview the first rows of the annotation table.
df.head(n=5)

In [None]:
# Inspect the first annotation row on its own (image file names, then condition values).
img_information = df.iloc[0]
print(img_information)

In [None]:
import os
import pandas as pd
from torchvision.io import read_image

class CustomDataset(Dataset):
    """Paired-image dataset driven by an annotations CSV.

    Each CSV row holds two image file names followed by numeric values:
    column 0 is the output (target) image, column 1 the input image, and
    columns 2 onward are returned as a float32 label vector.
    Items are returned as ``(image_in, image_out, label)``.

    Args:
        annotations_file: Path to the CSV, read with ``pd.read_csv``.
        img_dir: Directory that the CSV's image file names are relative to.
        transform: Optional callable applied to both images.
        target_transform: Optional callable applied to the label vector.
    """

    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def _load_rgb(self, file_name):
        """Open ``file_name`` under ``img_dir`` and return it as a loaded RGB image."""
        path = os.path.join(self.img_dir, file_name)
        # The UNet's first conv takes exactly 3 channels, so normalise RGBA /
        # grayscale / palette files here. convert() returns a fully loaded copy,
        # so the source file handle can be released by the context manager.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        row = self.img_labels.iloc[idx]
        image_in = self._load_rgb(row.iloc[1])
        image_out = self._load_rgb(row.iloc[0])
        label = row.iloc[2:].to_numpy(dtype=np.float32)

        if self.transform:
            image_in = self.transform(image_in)
            image_out = self.transform(image_out)

        if self.target_transform:
            label = self.target_transform(label)
        return image_in, image_out, label

In [None]:
# Resize every image to 256x256 (the resolution the UNet's fc0 layer is sized for)
# and convert it to a float tensor.
transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
# NOTE(review): target_transform is never passed to CustomDataset below, and
# ToTensor expects 2-D/3-D input, so it would likely reject the 1-D label
# vector anyway — confirm it can be removed.
target_transform = transforms.Compose([transforms.ToTensor()])

dataset = CustomDataset(annotations_file="/Users/AI/HMH/UTRC/Dataset/1/Dataset.csv",
                        img_dir="/Users/AI/HMH/UTRC/Dataset/1/Image/", transform=transform)

# 90 / 10 train / test split.
dataset_size = len(dataset)
train_size = int(dataset_size * 0.9)
test_size = dataset_size - train_size

# A dedicated generator makes the split independent of how much of the global
# RNG other cells have consumed, so re-running this cell cannot reshuffle
# already-trained samples into the test set. Seeded like seed_everything(), it
# should reproduce the split of a fresh Restart & Run All.
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=split_generator)

print(f"Training Data Size : {len(train_dataset)}")
print(f"Testing Data Size : {len(test_dataset)}")


In [None]:
# Peek at one training sample's target image (element 1 of the
# (image_in, image_out, label) tuple) to sanity-check loading and resizing.
t1 = train_dataset[4][1]
print(t1.dtype)
print(t1.shape, t1.min().item(), t1.max().item())
# ToTensor yields (C, H, W) and imshow wants (H, W, C). permute(1, 2, 0) keeps
# the image upright; transpose(2, 0) would give (W, H, C) and flip it across
# its diagonal.
plt.imshow(t1.permute(1, 2, 0))

In [None]:
class UNet(nn.Module):
    """Conditional U-Net mapping a 3-channel image plus a condition vector to a 3-channel image.

    The encoder is five double 3x3-conv stages, each followed by 2x max-pooling.
    The pooled bottleneck is flattened and projected to 64 features by ``fc0``.
    Those features are concatenated with the flattened condition ``y`` and passed
    through ``fc1``-``fc3`` to a 512-vector. That vector is tiled into a
    (512, h, w) map that seeds the decoder. The decoder mirrors the encoder with
    transposed-conv upsampling and skip connections to every encoder stage.

    Input constraints: ``fc0.in_features`` is fixed at 32768 = 512 * 8 * 8, so H
    and W must be multiples of 32 with (H / 32) * (W / 32) == 64 (e.g. 256 x 256).
    ``fc1`` expects the flattened condition to contain 5 values per sample.
    """

    def __init__(self):
        super().__init__()

        # Encoder: double 3x3 convs, each stage halving resolution with 2x2 max-pooling.
        self.enc1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.enc1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)                  # 256 -> 128

        self.enc2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.enc2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)                  # 128 -> 64

        self.enc3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.enc3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)                  # 64 -> 32

        self.enc4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.enc4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)                  # 32 -> 16

        self.enc5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.enc5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)                  # 16 -> 8

        # Bottleneck MLP: flattened (512, 8, 8) features -> 64, then joined with
        # the 5-value condition and expanded back to a 512-vector.
        self.fc0 = nn.Linear(in_features=32768, out_features=64)
        self.fc1 = nn.Linear(in_features=64 + 5, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=1024)
        self.fc3 = nn.Linear(in_features=1024, out_features=512)

        # Decoder: 2x transposed-conv upsampling, concatenation with the matching
        # encoder skip (hence the doubled input channels), then two 3x3 convs.
        # Layer creation order is unchanged so seeded weight init stays identical.
        self.upsample5 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec5_1 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.dec5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)

        self.upsample4 = nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.dec4_1 = nn.Conv2d(1024, 512, kernel_size=3, padding=1)
        self.dec4_2 = nn.Conv2d(512, 256, kernel_size=3, padding=1)

        self.upsample3 = nn.ConvTranspose2d(256, 256, 2, stride=2)
        self.dec3_1 = nn.Conv2d(512, 256, kernel_size=3, padding=1)
        self.dec3_2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)

        self.upsample2 = nn.ConvTranspose2d(128, 128, 2, stride=2)
        self.dec2_1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.dec2_2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)

        self.upsample1 = nn.ConvTranspose2d(64, 64, 2, stride=2)
        self.dec1_1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.dec1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.dec1_3 = nn.Conv2d(64, 3, kernel_size=1)

        self.relu = nn.ReLU()

    @staticmethod
    def _tile_condition(y, height, width):
        """Tile (B, C) condition features into a (B, C, height, width) map.

        NOTE(review): this repeats the whole C-vector height*width times in flat
        memory order and then reshapes. Element (c, h, w) therefore holds
        y[:, (c * height * width + h * width + w) % C]. It is NOT a per-channel
        spatial broadcast (that would be ``y[:, :, None, None].expand(...)``).
        The layout is kept as-is: changing it would silently alter the behaviour
        of checkpoints trained with this code.
        """
        return y.repeat(1, height * width).reshape(-1, y.shape[1], height, width)

    def forward(self, x, y):
        """Run the network.

        Args:
            x: Image batch of shape (B, 3, H, W); see the class docstring for
               the constraints on H and W.
            y: Condition batch; flattened from dim 1 it must hold 5 values per sample.

        Returns:
            Tensor of shape (B, 3, H, W). No output activation is applied.
        """
        # Encoder Network (spatial sizes in comments assume a 256 x 256 input)
        x = self.relu(self.enc1_1(x))
        e1 = self.relu(self.enc1_2(x))        # (64, 256, 256)
        x = self.pool1(e1)                    # 256 --> 128

        x = self.relu(self.enc2_1(x))
        e2 = self.relu(self.enc2_2(x))        # (128, 128, 128)
        x = self.pool2(e2)                    # 128 --> 64

        x = self.relu(self.enc3_1(x))
        e3 = self.relu(self.enc3_2(x))        # (256, 64, 64)
        x = self.pool3(e3)                    # 64 --> 32

        x = self.relu(self.enc4_1(x))
        e4 = self.relu(self.enc4_2(x))        # (512, 32, 32)
        x = self.pool4(e4)                    # 32 --> 16

        x = self.relu(self.enc5_1(x))
        e5 = self.relu(self.enc5_2(x))        # (512, 16, 16)
        x = self.pool5(e5)                    # 16 --> 8

        # Bottleneck: compress the pooled features to 64 values.
        flatten_x = self.relu(self.fc0(torch.flatten(x, start_dim=1)))

        # Condition: append the flattened condition and expand to a 512-vector.
        y = torch.flatten(y, start_dim=1)
        y = torch.cat([flatten_x, y], dim=1)
        y = self.relu(self.fc1(y))
        y = self.relu(self.fc2(y))
        y = self.relu(self.fc3(y))

        # Only the bottleneck's spatial size is reused here. H and W are taken
        # separately so non-square inputs still line up with the e5 skip.
        cond = self._tile_condition(y, x.shape[2], x.shape[3])

        # Decoder Network
        x = self.upsample5(cond)
        x = torch.cat([x, e5], dim=1)         # (512,16,16) + (512,16,16) --> (1024,16,16)
        x = self.relu(self.dec5_1(x))
        x = self.relu(self.dec5_2(x))

        x = self.upsample4(x)
        x = torch.cat([x, e4], dim=1)         # (512,32,32) + (512,32,32) --> (1024,32,32)
        x = self.relu(self.dec4_1(x))
        x = self.relu(self.dec4_2(x))         # --> (256,32,32)

        x = self.upsample3(x)
        x = torch.cat([x, e3], dim=1)         # (256,64,64) + (256,64,64) --> (512,64,64)
        x = self.relu(self.dec3_1(x))
        x = self.relu(self.dec3_2(x))         # --> (128,64,64)

        x = self.upsample2(x)
        x = torch.cat([x, e2], dim=1)         # (128,128,128) + (128,128,128) --> (256,128,128)
        x = self.relu(self.dec2_1(x))
        x = self.relu(self.dec2_2(x))         # --> (64,128,128)

        x = self.upsample1(x)
        x = torch.cat([x, e1], dim=1)         # (64,256,256) + (64,256,256) --> (128,256,256)
        x = self.relu(self.dec1_1(x))
        x = self.relu(self.dec1_2(x))
        return self.dec1_3(x)                 # 1x1 conv to 3 channels

In [None]:
# Layer-by-layer summary for one 3x256x256 image and a 5-value condition, on CPU.
from torchsummary import summary
my_model = UNet().to("cpu")
summary(my_model, input_size=[(3, 256, 256), (5, 1, 1)], device='cpu')

In [None]:
# Prefer the first GPU when one is visible, otherwise fall back to the CPU.
if not torch.cuda.is_available():
    dev = "cpu"
else:
    dev = "cuda:0"
    print("gpu up")
device = torch.device(dev)

In [None]:
# Model: a freshly initialised UNet placed on the selected device.
model = UNet().to(device=device)

# Load the Dataset: shuffle only the training batches; keep test order fixed.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# NOTE(review): this rebinds `optim`, shadowing the `torch.optim` module imported
# at the top of the notebook; the training cell calls optim.zero_grad()/optim.step().
optim = Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Pick one held-out sample (input image, target image, condition vector) from the
# first test batch. Keep a leading batch dimension of 1 so it can go straight
# into the model.
test_in, test_out, cond = next(iter(test_dataloader))

case_control_num = 15
first_batch_len = len(test_in)
if not -first_batch_len <= case_control_num < first_batch_len:
    raise IndexError(
        f"case_control_num={case_control_num} is outside the first test batch, "
        f"which only has {first_batch_len} samples"
    )

test_in = torch.unsqueeze(test_in[case_control_num], dim=0)
test_out = torch.unsqueeze(test_out[case_control_num], dim=0)
cond = torch.unsqueeze(cond[case_control_num], dim=0)

print(test_in.size())
print(test_out.size())
print(cond.size())
print(cond)

In [None]:
# Training
# Loop variables are named batch_* so they do not overwrite the held-out sample
# (`test_in`, `test_out`, `cond`) chosen earlier, which the inference cells reuse.
criterion = nn.MSELoss()
model.train()
for epoch in range(300):
    iterator = tqdm.tqdm(train_dataloader)

    for batch_in, batch_target, batch_cond in iterator:
        optim.zero_grad()

        pred = model(batch_in.to(device), batch_cond.to(device))

        loss = criterion(pred, batch_target.to(device))
        loss.backward()
        optim.step()

        iterator.set_description(f"epoch:{epoch+1} loss:{loss.item()}")

torch.save(model.state_dict(), "./UNet.pth")


In [None]:
# Load a saved checkpoint and predict the held-out sample selected earlier.
# NOTE(review): this loads a different file from the "./UNet.pth" written by the
# training cell — confirm which checkpoint is intended.
model.load_state_dict(torch.load("UNet_paper Net_nomalized(256_256)(512,8,8)_OK_231017.pth", map_location = "cpu"))
model.eval()

with torch.no_grad():
    pred_image = model(test_in.to(device), cond.to(device))
    # (1, C, H, W) -> (H, W, C) for imshow. permute keeps rows as rows, whereas
    # transpose(2, 0) would give (W, H, C) and show the image flipped diagonally.
    pred_image = torch.squeeze(pred_image, dim=0).permute(1, 2, 0)

# Guarded so re-running this cell does not permute the already-converted target again.
if test_out.dim() == 4:
    test_out = torch.squeeze(test_out, dim=0).permute(1, 2, 0)

print(pred_image.size())
print(test_out.size())

In [None]:
# Side-by-side comparison: ground-truth target (left) vs. model prediction (right).
panels = (
    (test_out, f"real image : {cond}"),
    (pred_image, f"predicted image : {cond}"),
)
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture.cpu())
    plt.title(caption)

In [None]:


# Feed the same held-out input image through the model under several
# hand-picked condition vectors and show the predictions in a 2x2 grid.
# NOTE(review): each vector needs 5 values, matching the label columns (index 2
# onward) of the annotations CSV — confirm their meaning/units there.
custom_conds = [
    [4.49, 2.57, 2.95, 1.44, 1.9],
    [4.48, 2.4, 2.4, 1.2, 1.2],
    [6, 2.4, 2.4, 2.4, 2.4],
    [6.72, 2.4, 2.4, 2.4, 2.4],
]

plt.figure(figsize=(16, 16))
with torch.no_grad():
    # Distinct names keep the held-out `cond` used by earlier cells intact.
    for panel, values in enumerate(custom_conds, start=1):
        custom_cond = torch.tensor(values, dtype=torch.float32).unsqueeze(0)
        img = model(test_in.to(device), custom_cond.to(device))
        # (1, C, H, W) -> (H, W, C); permute keeps the image upright, unlike transpose(2, 0).
        img = img.squeeze(0).permute(1, 2, 0).cpu().numpy()
        plt.subplot(2, 2, panel)
        plt.title(", ".join(f"{v:g}" for v in values))
        plt.imshow(img)