In [12]:
# registration inference
import torch
import torch.nn as nn
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import Dataset, DataLoader
import cv2


from glob import glob

In [13]:
# Augmentation with Albumentations
train_transform = A.Compose([
    A.Rotate(limit=25, p=0.8),
    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.8),
    A.ElasticTransform(alpha=60, sigma=12, p=0.2),
    A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.2),
    A.OpticalDistortion(distort_limit=0.2, shift_limit=0.05, p=0.2),
    A.Resize(512, 512),  # Rescale all images to a fixed size
    ToTensorV2()
])

val_transform = A.Compose([
    A.Resize(512, 512),
    ToTensorV2()
])

class LungSegmentationDataset(Dataset):
    def __init__(self, images, masks, transform=None):
        self.images = images
        self.masks = masks
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = cv2.imread(self.images[idx], cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(self.masks[idx], cv2.IMREAD_GRAYSCALE)
        file_name = self.images[idx].split('/')[-1]
        
        image = self.transform(image=image)['image'][0]
        mask = self.transform(image=mask)['image'][0]
        
        image = image / 255.0
        mask = mask / 255.0
        
        return image, mask, file_name




mask_datalist = glob('/hdd/project/cylce_het/dataset/ralo_mask/*.jpg')
image_datalist = glob("/hdd/project/cylce_het/dataset/ralo_raw_images/*.jpg")

sorted(mask_datalist), sorted(image_datalist)

# Define dataset and dataloader
train_dataset = LungSegmentationDataset(image_datalist, mask_datalist, transform=val_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Spatial Transformer Network (STN)
class STN(nn.Module):
    def __init__(self, in_shape=(1,512,512), mask_resize: int = 512,
                 dense_neurons=50, freeze_align_model=False):
        super(STN, self).__init__()
        
        assert not in_shape[1] % mask_resize, "The STN size must be a multiple of mask size"
        trainable = not freeze_align_model

        # MaxPooling for input adaptation
        self.pool1 = nn.MaxPool2d(kernel_size=(in_shape[1] // mask_resize, in_shape[2] // mask_resize))

        # Convolutional and pooling layers
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.conv1 = nn.Conv2d(in_channels=in_shape[0], out_channels=20, kernel_size=5, stride=1)
        if not trainable:
            self.conv1.requires_grad_(False)

        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=20, kernel_size=5, stride=1)
        if not trainable:
            self.conv2.requires_grad_(False)

        # Flatten layer
        self.flatten = nn.Flatten()

        # Fully connected layers
        self.fc1 = nn.Linear(in_features=self.calculate_flatten_size(in_shape), out_features=dense_neurons)
        if not trainable:
            self.fc1.requires_grad_(False)
        
        self.relu = nn.ReLU()

        # Final alignment layer (6 affine parameters)
        self.fc2 = nn.Linear(in_features=dense_neurons, out_features=6)
        if not trainable:
            self.fc2.requires_grad_(False)
        
        # Initialize the alignment layer to the identity transformation
        self.fc2.weight.data.zero_()
        self.fc2.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    def calculate_flatten_size(self, in_shape):
        # Dummy input to calculate the size after convolution and pooling
        dummy_input = torch.randn(1, in_shape[0], in_shape[1], in_shape[2])
        x = self.pool1(dummy_input)
        x = self.pool2(self.conv1(x))
        x = self.pool3(self.conv2(x))
        return x.numel()

    def forward(self, x):
        # Max Pooling and Convolutional layers
        xs = self.pool1(x)
        xs = self.pool2(self.conv1(xs))
        xs = self.pool3(self.conv2(xs))
        
        # Flatten the feature maps
        xs = self.flatten(xs)

        # Fully connected layers and ReLU
        xs = self.fc1(xs)
        xs = self.relu(xs)
        
        # Output the 6 affine parameters
        theta = self.fc2(xs)
        theta = theta.view(-1, 2, 3)  # Reshape to 2x3 affine matrix
        # Generate affine grid and perform sampling
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        x = F.grid_sample(x, grid, align_corners=False)
        
        return x,grid



# Training loop setup

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
stn_model = STN().to(device)

stn_model.load_state_dict(torch.load('/hdd/project/cylce_het/checkpoint/stn/checkpoint.pth'))

# inference and save all in batch

save_folder = "/hdd/project/cylce_het/dataset/ralo_registered_images/"

stn_model.eval()
with torch.no_grad():
    for idx, (image, mask, file_name) in enumerate(train_loader):
        image = image.to(device,dtype=torch.float)
        mask = mask.to(device,dtype=torch.float)
        image= image.unsqueeze(1)
        mask= mask.unsqueeze(1)
        
        transformed_image, grid = stn_model(mask)
        
        img_res = F.grid_sample(image, grid, align_corners=False)
        
        for i in range(image.size(0)):
            save_path = save_folder + file_name[i]
            # 만약 jpg인 경우 png로 저장
            if save_path.endswith(".jpg"):
                save_path = save_path.replace(".jpg", ".png")
            cv2.imwrite(save_path, img_res[i].squeeze().cpu().numpy()*255)
            print(f"Saved {save_path}")






Using device: cuda
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/809.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/1628.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/1264.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/1009.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/2179.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/1160.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/1437.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/2243.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/2268.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/8.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/658.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/405.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/2225.png
Saved /hdd/project/cylce_het/dataset/ralo_registered_images/2054.png
Saved /hdd/project/cy