In [None]:
# Mount Google Drive at /content/drive; the directory paths used by later
# cells all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

# `os` is imported here and relied upon by the later cells (path joins,
# directory listing, makedirs).
import os

Mounted at /content/drive


In [7]:
import numpy as np
import cv2

# Function to create a binary (0 / 255) mask from a BGR image
def create_mask(image, threshold=127):
    """Threshold an image into a strictly binary mask.

    The image is converted to grayscale; every pixel whose gray value is
    greater than ``threshold`` becomes 255 and every other pixel becomes 0.

    The previous implementation used ``cv2.THRESH_TOZERO_INV`` followed by
    ``255 - mask``. That flag zeroes bright pixels but leaves dark pixels at
    their original gray value, so after inversion the result held values in
    128..255 everywhere rather than a 0/255 mask. The polarity for bright
    pixels (-> 255) is kept; dark pixels now map to 0.

    Args:
        image: BGR image array as returned by ``cv2.imread``.
        threshold: Gray level separating background (<= threshold -> 0)
            from foreground (> threshold -> 255). Defaults to 127.

    Returns:
        A single-channel array with the same height/width as ``image``
        containing only the values 0 and 255.

    Raises:
        ValueError: If ``image`` is None (e.g. ``cv2.imread`` failed).
    """
    if image is None:
        raise ValueError("create_mask received None; the image could not be loaded")

    gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    # THRESH_BINARY: dst = 255 where src > threshold, else 0.
    _, binary_mask = cv2.threshold(gray_image, threshold, 255, cv2.THRESH_BINARY)
    return binary_mask

# Set directory paths
input_dir = '/content/drive/My Drive/Research/Train Images/'
output_dir = '/content/drive/My Drive/Research/Train Masks/'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Process each image in the input directory (sorted for a reproducible order)
for filename in sorted(os.listdir(input_dir)):
    # Match extensions case-insensitively so '.jpg' / '.PNG' are not silently skipped
    if not filename.lower().endswith(('.jpg', '.png')):  # Adjust as needed
        continue

    # Load the image; cv2.imread returns None instead of raising on failure
    image_path = os.path.join(input_dir, filename)
    image = cv2.imread(image_path)
    if image is None:
        print(f"Skipping unreadable image: {image_path}")
        continue

    # Create the mask
    mask = create_mask(image)

    # Masks are always written as '<stem>_mask.png' (lossless, single channel)
    mask_filename = os.path.splitext(filename)[0] + '_mask.png'
    mask_path = os.path.join(output_dir, mask_filename)

    # cv2.imwrite reports failure through its return value rather than raising
    if not cv2.imwrite(mask_path, mask):
        print(f"Failed to write mask: {mask_path}")

In [None]:
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

class UNet(nn.Module):
    """U-Net with four down/up-sampling stages and skip connections.

    Each decoder stage upsamples with a transposed convolution, concatenates
    the matching encoder feature map, and fuses the result with a double
    3x3 convolution block.

    The fusion blocks are created once in ``__init__`` and stored as
    sub-modules. Previously they were constructed inside ``forward`` on
    every call, so their weights were re-randomised on each pass, were not
    part of ``model.parameters()`` (never optimised) and were missing from
    ``state_dict()`` (never saved). Checkpoints written by that version lack
    the ``decoderN_conv.*`` keys and need ``strict=False`` to load.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output logit map.
        base_channels: Width of the first encoder stage; each deeper stage
            doubles it. The default of 64 gives 64/128/256/512/1024.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super(UNet, self).__init__()

        c1, c2, c3, c4, c5 = (base_channels * mult for mult in (1, 2, 4, 8, 16))

        # Downsampling path
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.encoder3 = self.conv_block(c2, c3)
        self.encoder4 = self.conv_block(c3, c4)

        # Bottleneck
        self.bottleneck = self.conv_block(c4, c5)

        # Upsampling path: transposed conv, then a conv block that fuses the
        # upsampled map with the skip connection (hence 2 * channels in).
        self.decoder4 = self.upconv_block(c5, c4)
        self.decoder4_conv = self.conv_block(2 * c4, c4)
        self.decoder3 = self.upconv_block(c4, c3)
        self.decoder3_conv = self.conv_block(2 * c3, c3)
        self.decoder2 = self.upconv_block(c3, c2)
        self.decoder2_conv = self.conv_block(2 * c2, c2)
        self.decoder1 = self.upconv_block(c2, c1)
        self.decoder1_conv = self.conv_block(2 * c1, c1)

        # Final 1x1 convolution producing raw logits (no activation)
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def upconv_block(self, in_channels, out_channels):
        """Return a stride-2 transposed conv -> BatchNorm -> ReLU block (doubles H and W)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def _decode(self, upconv, fuse, x, skip):
        """Upsample ``x``, concatenate it with ``skip`` and apply the fusion block."""
        x = upconv(x)
        # max_pool2d floors odd sizes, so the upsampled map can be one pixel
        # smaller than the skip tensor; resize so the concatenation is valid.
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return fuse(torch.cat((x, skip), dim=1))

    def forward(self, x):
        """Map a (N, in_channels, H, W) batch to (N, out_channels, H, W) logits."""
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(F.max_pool2d(enc1, kernel_size=2))
        enc3 = self.encoder3(F.max_pool2d(enc2, kernel_size=2))
        enc4 = self.encoder4(F.max_pool2d(enc3, kernel_size=2))

        # Bottleneck
        bottleneck = self.bottleneck(F.max_pool2d(enc4, kernel_size=2))

        # Decoder with skip connections
        dec4 = self._decode(self.decoder4, self.decoder4_conv, bottleneck, enc4)
        dec3 = self._decode(self.decoder3, self.decoder3_conv, dec4, enc3)
        dec2 = self._decode(self.decoder2, self.decoder2_conv, dec3, enc2)
        dec1 = self._decode(self.decoder1, self.decoder1_conv, dec2, enc1)

        # Final output
        return self.final_conv(dec1)

In [None]:
from torch.utils.data import Dataset, DataLoader

class ImageMaskDataset(Dataset):
    """Dataset of (image, mask) pairs loaded from two directories.

    Every image file in ``image_dir`` is paired with the mask
    ``<stem>_mask.png`` in ``mask_dir``. Both are resized to 256x256 and
    scaled to the range [0, 1].

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the ``*_mask.png`` masks.
        transform: Optional callable applied to the image tensor and then,
            separately, to the mask tensor.
    """

    # Only files with these (case-insensitive) extensions are treated as
    # images; anything else in the directory is ignored.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        # Sorted so that index -> file is stable across runs and machines
        # (os.listdir order is arbitrary).
        self.image_filenames = sorted(
            name for name in os.listdir(image_dir)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.transform = transform

    def __len__(self):
        """Number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    def _mask_path(self, image_filename):
        """Return the path of the mask belonging to ``image_filename``."""
        stem = os.path.splitext(image_filename)[0]
        return os.path.join(self.mask_dir, stem + '_mask.png')

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, mask)`` where ``image`` is a float32 tensor of shape
            (3, 256, 256) in cv2's channel order and ``mask`` is a float32
            tensor of shape (256, 256), both scaled to [0, 1].

        Raises:
            ValueError: If the image or its mask cannot be read.
        """
        img_name = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_name = self._mask_path(self.image_filenames[idx])

        image = cv2.imread(img_name)
        mask = cv2.imread(mask_name, cv2.IMREAD_GRAYSCALE)

        # cv2.imread returns None on failure; check before any further cv2
        # call, because cv2.resize(None, ...) raises an unrelated error.
        if image is None:
            raise ValueError(f"Failed to load image: {img_name}")
        if mask is None:
            raise ValueError(f"Failed to load mask: {mask_name}")

        image = cv2.resize(image, (256, 256))
        # Nearest-neighbour keeps the mask's original values; the default
        # bilinear interpolation would blend 0/255 into intermediate grays.
        mask = cv2.resize(mask, (256, 256), interpolation=cv2.INTER_NEAREST)

        # Normalize image and mask to [0, 1]
        image = image.astype(np.float32) / 255.0
        mask = mask.astype(np.float32) / 255.0

        # Convert to PyTorch tensors; image goes from (H, W, C) to (C, H, W)
        image = torch.from_numpy(image).permute(2, 0, 1)
        mask = torch.from_numpy(mask)

        if self.transform:
            # NOTE: the transform is invoked independently on image and mask,
            # so a random geometric transform would not stay aligned.
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask

In [None]:
import matplotlib.pyplot as plt
#Replaced with item in QuickCropper.
def extract_entities(model, image_dir, output_dir):
    """Run ``model`` on every image in ``image_dir`` and save binary masks.

    Each image is scaled to [0, 1], passed through the model, and the
    sigmoid of the output is thresholded at 0.5. The resulting 0/255 mask
    is written to ``output_dir`` as ``<stem>_mask.png``.

    Args:
        model: A ``torch.nn.Module`` returning one logit map per image.
        image_dir: Directory with the input images. Files without a known
            image extension, or that cv2 cannot decode, are skipped.
        output_dir: Directory for the masks; created if it does not exist.
    """
    image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')
    os.makedirs(output_dir, exist_ok=True)

    # Run inference on whatever device the model lives on (CPU if the
    # model has no parameters at all).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    with torch.no_grad():
        for img_name in sorted(os.listdir(image_dir)):
            if not img_name.lower().endswith(image_extensions):
                continue

            img_path = os.path.join(image_dir, img_name)
            image = cv2.imread(img_path)
            if image is None:  # cv2.imread signals failure with None
                print(f"Skipping unreadable image: {img_path}")
                continue

            image = image.astype(np.float32) / 255.0
            # (H, W, C) -> (1, C, H, W)
            image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)

            output = model(image_tensor)
            # .cpu() is required before .numpy() when the model runs on GPU
            output = torch.sigmoid(output).squeeze().cpu().numpy()  # Get probabilities

            # Threshold to create binary mask
            binary_mask = (output > 0.5).astype(np.uint8) * 255

            # Build the name from the stem so any extension/case yields
            # '<stem>_mask.png'; a plain replace('.jpg', ...) left 'x.JPG'
            # unchanged and the mask was written under the image's own name.
            mask_name = os.path.splitext(img_name)[0] + '_mask.png'
            mask_output_path = os.path.join(output_dir, mask_name)
            cv2.imwrite(mask_output_path, binary_mask)

In [None]:
    # Directories
    image_dir = '/content/drive/My Drive/Research/Mussel Images/'
    mask_dir = '/content/drive/My Drive/Research/Mussel Masks/'

    # Training hyperparameters, kept in one place so the loop and the
    # progress message cannot drift apart.
    num_epochs = 15
    batch_size = 2
    learning_rate = 0.001

    # Create Dataset and DataLoader
    dataset = ImageMaskDataset(image_dir, mask_dir)
    if len(dataset) == 0:
        # Fail fast with a clear message instead of failing obscurely later.
        raise RuntimeError(f"No training images found in {image_dir}")
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Initialize model, criterion, optimizer
    model = UNet(in_channels=3, out_channels=1)
    # BCEWithLogitsLoss applies the sigmoid itself, so the model outputs raw logits
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Train the model
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, masks in dataloader:
            optimizer.zero_grad()
            outputs = model(images)
            # Masks arrive as (N, H, W); add the channel axis to match (N, 1, H, W)
            loss = criterion(outputs, masks.unsqueeze(1))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

    # Save the trained model
    torch.save(model.state_dict(), 'unet_model.pth')

KeyboardInterrupt: 