In [12]:
import os
from tifffile import TiffFile
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models import vgg16_bn
import torch.optim as optim
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary


In [2]:
data_dir = '../platelet_data'

with TiffFile(os.path.join(data_dir, 'train-images.tif')) as tif:
    train_img = tif.asarray()
    
with TiffFile(os.path.join(data_dir, 'train-labels.tif')) as tif:
    train_label = tif.asarray()
    
with TiffFile(os.path.join(data_dir, 'eval-images.tif')) as tif:
    eval_img = tif.asarray()
    
with TiffFile(os.path.join(data_dir, 'eval-labels.tif')) as tif:
    eval_label = tif.asarray()
    
with TiffFile(os.path.join(data_dir, 'test-images.tif')) as tif:
    test_img = tif.asarray()
    
with TiffFile(os.path.join(data_dir, 'test-labels.tif')) as tif:
    test_label = tif.asarray()

In [3]:
M = max(np.max(train_img), np.max(eval_img), np.max(test_img))
m = min(np.min(train_img), np.min(eval_img), np.min(test_img))
train_img_norm = (train_img - m) / (M - m)
eval_img_norm = (eval_img - m) / (M - m)

In [54]:
label = train_label[0]
print(label.shape)
label

(800, 800)


array([[0, 0, 0, ..., 1, 1, 1],
       [0, 0, 0, ..., 1, 1, 1],
       [0, 0, 0, ..., 1, 1, 1],
       ...,
       [0, 0, 0, ..., 1, 1, 1],
       [0, 0, 0, ..., 1, 1, 1],
       [0, 0, 0, ..., 1, 1, 1]], dtype=uint16)

In [55]:
label = torch.from_numpy(label.astype(np.int16))
print(label.shape)
label

torch.Size([800, 800])


tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]], dtype=torch.int16)

In [56]:
label = label.long()
print(label.shape)
label

torch.Size([800, 800])


tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])

In [57]:
label = torch.nn.functional.one_hot(label, num_classes=7)
print(label.shape)
label

torch.Size([800, 800, 7])


tensor([[[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0]],

        ...,

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0],
         [0, 1, 0,  ..., 0, 0, 0]],

        [[1, 0, 0,  ..., 0, 0, 0],
         [1, 0, 0,  ..., 0, 0, 0],
         [1,

In [58]:
label = label.to(torch.float32)
label = label.permute(2, 0, 1)
print(label.shape)
label[2]

torch.Size([7, 800, 800])


tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [84]:
image = train_img_norm[0]
print(image.shape)
image

(800, 800)


array([[0.29224105, 0.28794763, 0.27185214, ..., 0.16621101, 0.15968425,
        0.18733154],
       [0.29041201, 0.2864459 , 0.28919908, ..., 0.14849827, 0.14524451,
        0.15616095],
       [0.29035425, 0.28727378, 0.27456681, ..., 0.17294956, 0.17814786,
        0.17216018],
       ...,
       [0.21216789, 0.22475934, 0.21274548, ..., 0.10785522, 0.12606854,
        0.12757027],
       [0.17462457, 0.18203697, 0.19162495, ..., 0.11160955, 0.1406238 ,
        0.14597613],
       [0.19928764, 0.18492491, 0.16105121, ..., 0.12724297, 0.11033885,
        0.10375433]])

In [87]:
transforms.ToTensor()(image).shape
transforms.ToTensor()(image)
transforms.ToTensor()(image).permute(1, 2, 0).shape

torch.Size([800, 800, 1])

In [4]:
# Custom Dataset class
class CustomDataset(Dataset):
    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        label = torch.from_numpy(label.astype(np.int64)).clone()
        label = torch.nn.functional.one_hot(label.long(), num_classes=7)
        label = label.to(torch.float32)
        label = label.permute(2, 0, 1)

        if self.transform:
            image = self.transform(image)

        return image, label

# Transformations (if needed)
transform = transforms.Compose([
    transforms.ToTensor(), # Convert numpy array to PyTorch Tensor
    # Add additional transformations here if required
])


# Creating the datasets
train_dataset = CustomDataset(train_img_norm, train_label, transform=transform)
eval_dataset = CustomDataset(eval_img_norm, eval_label, transform=transform)

# DataLoader creation
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
eval_loader = DataLoader(eval_dataset, batch_size=2, shuffle=False)

In [96]:
train_dataset[10][0]

tensor([[[0.2792, 0.2540, 0.2778,  ..., 0.0720, 0.0688, 0.0777],
         [0.2843, 0.2874, 0.2796,  ..., 0.0576, 0.0887, 0.0880],
         [0.2452, 0.2731, 0.2922,  ..., 0.0762, 0.0782, 0.1064],
         ...,
         [0.1566, 0.1436, 0.1157,  ..., 0.0972, 0.1052, 0.1111],
         [0.1813, 0.1771, 0.1658,  ..., 0.1242, 0.1452, 0.1126],
         [0.1536, 0.1423, 0.1382,  ..., 0.1231, 0.1394, 0.1448]]],
       dtype=torch.float64)

In [5]:
class TwoConvBlock(nn.Module):
    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, middle_channels, kernel_size = 3, padding="same")
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.rl = nn.ReLU()
        self.conv2 = nn.Conv2d(middle_channels, out_channels, kernel_size = 3, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.rl(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.rl(x)
        return x

class UpConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 2, padding="same")
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        x = self.up(x)
        x = self.bn1(x)
        x = self.conv(x)
        x = self.bn2(x)
        return x

class UNet_2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.TCB1 = TwoConvBlock(1, 64, 64)
        self.TCB2 = TwoConvBlock(64, 128, 128)
        self.TCB3 = TwoConvBlock(128, 256, 256)
        self.TCB4 = TwoConvBlock(256, 512, 512)
        self.TCB5 = TwoConvBlock(512, 1024, 1024)
        self.TCB6 = TwoConvBlock(1024, 512, 512)
        self.TCB7 = TwoConvBlock(512, 256, 256)
        self.TCB8 = TwoConvBlock(256, 128, 128)
        self.TCB9 = TwoConvBlock(128, 64, 64)
        self.maxpool = nn.MaxPool2d(2, stride = 2)
        
        self.UC1 = UpConv(1024, 512) 
        self.UC2 = UpConv(512, 256) 
        self.UC3 = UpConv(256, 128) 
        self.UC4= UpConv(128, 64)

        self.conv1 = nn.Conv2d(64, 7, kernel_size = 1)
        self.soft = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.TCB1(x)
        x1 = x
        x = self.maxpool(x)

        x = self.TCB2(x)
        x2 = x
        x = self.maxpool(x)

        x = self.TCB3(x)
        x3 = x
        x = self.maxpool(x)

        x = self.TCB4(x)
        x4 = x
        x = self.maxpool(x)

        x = self.TCB5(x)

        x = self.UC1(x)
        x = torch.cat([x4, x], dim = 1)
        x = self.TCB6(x)

        x = self.UC2(x)
        x = torch.cat([x3, x], dim = 1)
        x = self.TCB7(x)

        x = self.UC3(x)
        x = torch.cat([x2, x], dim = 1)
        x = self.TCB8(x)

        x = self.UC4(x)
        x = torch.cat([x1, x], dim = 1)
        x = self.TCB9(x)

        x = self.conv1(x)

        return x


In [14]:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet_2D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)
print(device)

cpu


In [15]:
summary(unet, (1, 800, 800))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 800, 800]             640
       BatchNorm2d-2         [-1, 64, 800, 800]             128
              ReLU-3         [-1, 64, 800, 800]               0
            Conv2d-4         [-1, 64, 800, 800]          36,928
       BatchNorm2d-5         [-1, 64, 800, 800]             128
              ReLU-6         [-1, 64, 800, 800]               0
      TwoConvBlock-7         [-1, 64, 800, 800]               0
         MaxPool2d-8         [-1, 64, 400, 400]               0
            Conv2d-9        [-1, 128, 400, 400]          73,856
      BatchNorm2d-10        [-1, 128, 400, 400]             256
             ReLU-11        [-1, 128, 400, 400]               0
           Conv2d-12        [-1, 128, 400, 400]         147,584
      BatchNorm2d-13        [-1, 128, 400, 400]             256
             ReLU-14        [-1, 128, 4

In [7]:
def criterion(output, target):
    return nn.BCEWithLogitsLoss()(output, target)

In [82]:
train_loader

<torch.utils.data.dataloader.DataLoader at 0x16c827b80>

In [9]:
history = {"train_loss": []}
n = 0
m = 0

for epoch in range(15):
  train_loss = 0
  val_loss = 0

  unet.train()
  for i, data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)
    print(labels.shape)
    inputs = inputs.float().to(device)
    labels = labels.float().to(device)
    print(inputs.dtype)
    optimizer.zero_grad()
    outputs = unet(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    history["train_loss"].append(loss.item())
    n += 1
    # if i % ((len(df)//BATCH_SIZE)//10) == (len(df)//BATCH_SIZE)//10 - 1:
    print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.5f}")
    n = 0
    train_loss = 0
    train_acc = 0


  unet.eval()
  with torch.no_grad():
    for i, data in enumerate(eval_loader):
      inputs, labels = data
      inputs = inputs.float().to(device)
      labels = labels.float().to(device)
      outputs = unet(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      m += 1
      # if i % (len(val_df)//BATCH_SIZE) == len(val_df)//BATCH_SIZE - 1:
      print(f"epoch:{epoch+1}  index:{i+1}  val_loss:{val_loss/m:.5f}")
      m = 0
      val_loss = 0
      val_acc = 0

  # torch.save(unet.state_dict(), f"./train_{epoch+1}.pth")
print("finish training")


torch.Size([4, 1, 800, 800])
torch.Size([4, 7, 800, 800])
torch.float32


  return F.conv2d(input, weight, bias, self.stride,


RuntimeError: MPS backend out of memory (MPS allocated: 14.07 GB, other allocations: 4.14 GB, max allowed: 18.13 GB). Tried to allocate 1.22 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [16]:
class UNet_2D(nn.Module):
    def __init__(self):
        super().__init__()
        self.TCB1 = TwoConvBlock(1, 32, 32)
        self.TCB2 = TwoConvBlock(32, 64, 64)
        self.TCB3 = TwoConvBlock(64, 128, 128)
        self.TCB4 = TwoConvBlock(128, 256, 256)
        self.TCB5 = TwoConvBlock(256, 512, 512)
        self.TCB6 = TwoConvBlock(512, 256, 256)
        self.TCB7 = TwoConvBlock(256, 128, 128)
        self.TCB8 = TwoConvBlock(128, 64, 64)
        self.TCB9 = TwoConvBlock(64, 32, 32)
        self.maxpool = nn.MaxPool2d(2, stride = 2)
        
        self.UC1 = UpConv(512, 256) 
        self.UC2 = UpConv(256, 128) 
        self.UC3 = UpConv(128, 64) 
        self.UC4= UpConv(64, 32)

        self.conv1 = nn.Conv2d(32, 7, kernel_size = 1)
        self.soft = nn.Softmax(dim = 1)

    def forward(self, x):
        x = self.TCB1(x)
        x1 = x
        x = self.maxpool(x)

        x = self.TCB2(x)
        x2 = x
        x = self.maxpool(x)

        x = self.TCB3(x)
        x3 = x
        x = self.maxpool(x)

        x = self.TCB4(x)
        x4 = x
        x = self.maxpool(x)

        x = self.TCB5(x)

        x = self.UC1(x)
        x = torch.cat([x4, x], dim = 1)
        x = self.TCB6(x)

        x = self.UC2(x)
        x = torch.cat([x3, x], dim = 1)
        x = self.TCB7(x)

        x = self.UC3(x)
        x = torch.cat([x2, x], dim = 1)
        x = self.TCB8(x)

        x = self.UC4(x)
        x = torch.cat([x1, x], dim = 1)
        x = self.TCB9(x)

        x = self.conv1(x)

        return x

In [22]:
# device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet_2D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)
print(device)

cpu


In [23]:
summary(unet, (1, 800, 800))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 800, 800]             320
       BatchNorm2d-2         [-1, 32, 800, 800]              64
              ReLU-3         [-1, 32, 800, 800]               0
            Conv2d-4         [-1, 32, 800, 800]           9,248
       BatchNorm2d-5         [-1, 32, 800, 800]              64
              ReLU-6         [-1, 32, 800, 800]               0
      TwoConvBlock-7         [-1, 32, 800, 800]               0
         MaxPool2d-8         [-1, 32, 400, 400]               0
            Conv2d-9         [-1, 64, 400, 400]          18,496
      BatchNorm2d-10         [-1, 64, 400, 400]             128
             ReLU-11         [-1, 64, 400, 400]               0
           Conv2d-12         [-1, 64, 400, 400]          36,928
      BatchNorm2d-13         [-1, 64, 400, 400]             128
             ReLU-14         [-1, 64, 4

In [24]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
unet = UNet_2D().to(device)
optimizer = optim.Adam(unet.parameters(), lr=0.001)
print(device)

mps


In [25]:
history = {"train_loss": []}
n = 0
m = 0

for epoch in range(15):
  train_loss = 0
  val_loss = 0

  unet.train()
  for i, data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)
    print(labels.shape)
    inputs = inputs.float().to(device)
    labels = labels.float().to(device)
    print(inputs.dtype)
    optimizer.zero_grad()
    outputs = unet(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    train_loss += loss.item()
    history["train_loss"].append(loss.item())
    n += 1
    # if i % ((len(df)//BATCH_SIZE)//10) == (len(df)//BATCH_SIZE)//10 - 1:
    print(f"epoch:{epoch+1}  index:{i+1}  train_loss:{train_loss/n:.5f}")
    n = 0
    train_loss = 0
    train_acc = 0


  unet.eval()
  with torch.no_grad():
    for i, data in enumerate(eval_loader):
      inputs, labels = data
      inputs = inputs.float().to(device)
      labels = labels.float().to(device)
      outputs = unet(inputs)
      loss = criterion(outputs, labels)
      val_loss += loss.item()
      m += 1
      # if i % (len(val_df)//BATCH_SIZE) == len(val_df)//BATCH_SIZE - 1:
      print(f"epoch:{epoch+1}  index:{i+1}  val_loss:{val_loss/m:.5f}")
      m = 0
      val_loss = 0
      val_acc = 0

  # torch.save(unet.state_dict(), f"./train_{epoch+1}.pth")
print("finish training")


torch.Size([4, 1, 800, 800])
torch.Size([4, 7, 800, 800])
torch.float32


RuntimeError: MPS backend out of memory (MPS allocated: 17.02 GB, other allocations: 1.12 GB, max allowed: 18.13 GB). Tried to allocate 1024 bytes on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).