In [2]:
#!pip install opencv-python

In [3]:
#!pip install matplotlib

In [4]:
import os
import shutil

import cv2
import numpy as np
import shapely.affinity  # explicit: `import shapely.wkb` does not guarantee this submodule is loaded
import shapely.wkb
from sqlalchemy import create_engine, event, Table
from sqlalchemy import inspect
from sqlalchemy.orm import sessionmaker

from quickannotator.db import db_session, Project, Image, AnnotationClass, Notification, Tile, Setting, Annotation, SearchCache, build_annotation_table_name

In [5]:
# Database URL; set the QA_DB_URL environment variable instead of editing this cell.
db_path = os.environ.get(
    "QA_DB_URL",
    "sqlite:////opt/QuickAnnotator/quickannotator/instance/quickannotator.db",
)
engine = create_engine(db_path)  # pass echo=True to log emitted SQL


# Initialize Spatialite extension
@event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
    """Load SpatiaLite into every new DBAPI connection opened by ``engine``.

    Registered for SQLAlchemy's pool "connect" event, so it runs once per raw
    sqlite3 connection. ``InitSpatialMetaData`` is only called when the
    ``spatial_ref_sys`` metadata table is missing, so reconnecting to an
    already-initialised database does not try to re-create it.

    Args:
        dbapi_connection: the raw DBAPI (sqlite3) connection just opened.
        connection_record: pool bookkeeping record (unused).
    """
    dbapi_connection.enable_load_extension(True)
    try:
        # Single quotes: a double-quoted "string" is an identifier in SQL and
        # only works on SQLite builds that still accept double-quoted strings.
        dbapi_connection.execute("SELECT load_extension('mod_spatialite')")
    finally:
        # Only the trusted mod_spatialite library should ever be loadable.
        dbapi_connection.enable_load_extension(False)

    has_spatial_metadata = dbapi_connection.execute(
        "SELECT 1 FROM sqlite_master WHERE type='table' AND name='spatial_ref_sys'"
    ).fetchone()
    if has_spatial_metadata is None:
        dbapi_connection.execute('SELECT InitSpatialMetaData(1);')



In [None]:
# Create any missing tables for these ORM models.
# The models' MetaData is reached through their own __table__, so this cell no
# longer depends on a `db` object that is never imported in this notebook.
models = [Image, AnnotationClass, Tile]
model_tables = [model.__table__ for model in models]
model_tables[0].metadata.create_all(bind=engine, tables=model_tables)

In [7]:
# Session factory bound to the SpatiaLite-enabled engine, and the single ORM
# session that every later cell queries through.
Session = sessionmaker()
Session.configure(bind=engine)
session = Session()

In [None]:
# Load every annotation class row and print the list.
annotation_classes = list(session.query(AnnotationClass))
print(annotation_classes)

In [None]:
# Annotation class whose tiles are exported and trained on below.
class_id = 2

# All tile rows belonging to that class.
# TODO: additionally restrict to tiles that carry ground truth (gt=True attribute).
tiles = session.query(Tile).filter(Tile.annotation_class_id == class_id).all()
print(tiles)

In [None]:
import openslide
from sqlalchemy import inspect


def rasterize_annotations(geoms, height, width, xoff, yoff):
    """Burn polygon geometries into a binary uint8 mask of shape (height, width).

    Args:
        geoms: iterable of shapely Polygon / MultiPolygon geometries.
        height, width: mask size in pixels.
        xoff, yoff: shift applied to every geometry to move it into mask-local
            pixel coordinates.

    Returns:
        np.ndarray (height, width) uint8 with 1 inside annotations, 0 elsewhere.

    MultiPolygons are drawn part by part. Each part's exterior and interior
    rings go to a single fillPoly call so holes stay unfilled.
    """
    mask = np.zeros((height, width), dtype=np.uint8)
    for geom in geoms:
        local = shapely.affinity.translate(geom, xoff=xoff, yoff=yoff)
        for part in getattr(local, "geoms", [local]):
            rings = [part.exterior, *part.interiors]
            cv2.fillPoly(mask, [np.asarray(ring.coords, dtype=np.int32) for ring in rings], 1)
    return mask


gtpred = 'gt'  # or 'pred' based on your requirement
inspector = inspect(engine)  # loop-invariant: one inspector for all tiles
metadata = Tile.__table__.metadata  # replaces the never-imported `db.metadata`

for tile in tiles:
    image_id = tile.image_id
    table_name = build_annotation_table_name(image_id, class_id, gtpred == 'gt')

    # Skip images that have no annotation table for this class.
    if not inspector.has_table(table_name):
        continue

    table = Table(table_name, metadata, autoload_with=engine)

    # Only annotations lying entirely within the tile are rasterized.
    annotations = session.query(table).filter(
        table.c.polygon.ST_Within(tile.geom)
    ).all()
    if not annotations:
        continue

    print("non zero", len(annotations))

    minx, miny, maxx, maxy = shapely.wkb.loads(tile.geom.data).bounds
    width = int(maxx - minx)
    height = int(maxy - miny)
    # One integer origin for both the image crop and the mask translation so
    # they stay pixel-aligned.
    x0, y0 = int(minx), int(miny)

    #--- image work
    image = session.query(Image).filter_by(id=image_id).first()
    if not image:
        continue

    print(image.path)
    # Context manager closes the slide file handle after the crop is read.
    with openslide.OpenSlide("../" + image.path) as slide:
        region = slide.read_region((x0, y0), 0, (width, height))
    region.save(f"io_{tile.id}.png")

    print(f"Width: {width}, Height: {height}")

    mask = rasterize_annotations(
        (shapely.wkb.loads(annotation.polygon.data) for annotation in annotations),
        height,
        width,
        xoff=-x0,
        yoff=-y0,
    )
    cv2.imwrite(f"mask_{tile.id}.png", mask)





In [60]:
# Roll back the session's current transaction — presumably run to recover
# after a query in an earlier cell raised.
session.rollback()

In [None]:
# Close the notebook's ORM session.
# NOTE(review): later cells (TileDataset, training) still query through
# `session`; confirm that reusing it after close() is intended.
session.close()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Tile id of the exported io_/mask_ PNG pair to show side by side.
num=324
io_image = mpimg.imread(f'io_{num}.png')
mask_image = mpimg.imread(f'mask_{num}.png')

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (io_image, 'IO Image', None),
    (mask_image, 'Mask Image', 'gray'),
)
for ax, (pixels, title, cmap) in zip(axes, panels):
    ax.imshow(pixels, cmap=cmap)
    ax.set_title(title)
    ax.axis('off')

plt.show()

In [None]:
#!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [12]:
import openslide
from torch.utils.data import Dataset, DataLoader, IterableDataset
import numpy as np
from PIL import Image as PILImage

In [39]:

import warnings

import torch
import torchvision.transforms as transforms


class TileDataset(IterableDataset):
    """Stream (image, mask) training pairs for annotated tiles.

    For each tile, the tile region is read from its slide with OpenSlide. The
    ground-truth annotations lying within the tile are rasterized into a
    binary mask. Pairs are memoised in ``self.cache`` under
    ``"{image_id}_{tile_id}"``, so later epochs skip the database and slide I/O.

    Yields:
        io_image: float tensor (3, H, W) from ``ToTensor`` (then ``transform``
            if one was given).
        mask_image: float32 tensor (1, H, W) holding 0/1. The channel axis and
            float dtype match what ``BCEWithLogitsLoss`` expects as target for
            single-class logits.

    Tiles are skipped when they have no annotation table, no annotations
    inside them, or no Image row.

    NOTE: ``__iter__`` does not shard ``tiles`` across DataLoader workers, and
    the cache is per-process. Keep ``num_workers=0``; otherwise every worker
    yields every tile.
    """

    def __init__(self, tiles, transform=None, class_id=None):
        """
        Args:
            tiles: Tile ORM rows to iterate over.
            transform: optional callable applied to the image tensor only.
            class_id: annotation class to read. None (default) keeps the old
                behaviour of using the notebook-level ``class_id`` at
                iteration time.
        """
        self.tiles = tiles
        self.cache = {}  # to be converted to memcached
        self.transform = transform
        self.class_id = class_id
        if transform:
            warnings.warn("MASK NOT TRANSFORMED: `transform` is applied to the image only", stacklevel=2)

    def __iter__(self):
        annotation_class_id = class_id if self.class_id is None else self.class_id
        inspector = inspect(engine)
        to_tensor = transforms.ToTensor()
        metadata = Tile.__table__.metadata  # replaces the never-imported `db.metadata`

        for tile in self.tiles:
            cache_key = f"{tile.image_id}_{tile.id}"
            if cache_key in self.cache:
                io_image, mask_image = self.cache[cache_key]
            else:
                sample = self._load_sample(tile, annotation_class_id, inspector, metadata, to_tensor)
                if sample is None:
                    continue
                self.cache[cache_key] = sample
                io_image, mask_image = sample

            yield io_image, mask_image

    def _load_sample(self, tile, annotation_class_id, inspector, metadata, to_tensor):
        """Build the (image, mask) pair for one tile, or return None if it must be skipped."""
        gtpred = 'gt'  # or 'pred' based on your requirement
        table_name = build_annotation_table_name(tile.image_id, annotation_class_id, gtpred == 'gt')
        if not inspector.has_table(table_name):
            return None

        table = Table(table_name, metadata, autoload_with=engine)
        annotations = session.query(table).filter(
            table.c.polygon.ST_Within(tile.geom)
        ).all()
        if not annotations:
            return None

        minx, miny, maxx, maxy = shapely.wkb.loads(tile.geom.data).bounds
        width = int(maxx - minx)
        height = int(maxy - miny)
        # Shared integer origin keeps the crop and the mask pixel-aligned.
        x0, y0 = int(minx), int(miny)

        image = session.query(Image).filter_by(id=tile.image_id).first()
        if not image:
            return None

        # Context manager closes the slide file handle after the crop is read.
        with openslide.OpenSlide("../" + image.path) as slide:
            region = slide.read_region((x0, y0), 0, (width, height))
        io_image = to_tensor(region.convert("RGB"))
        if self.transform:
            io_image = self.transform(io_image)

        mask = np.zeros((height, width), dtype=np.uint8)
        for annotation in annotations:
            local = shapely.affinity.translate(
                shapely.wkb.loads(annotation.polygon.data), xoff=-x0, yoff=-y0
            )
            # MultiPolygons part by part; exterior + holes in one call keeps holes empty.
            for part in getattr(local, "geoms", [local]):
                rings = [part.exterior, *part.interiors]
                cv2.fillPoly(mask, [np.asarray(ring.coords, dtype=np.int32) for ring in rings], 1)

        # Add the channel axis and cast to float. ToTensor is deliberately not used
        # here: it would rescale uint8 0/1 to 0 and 1/255.
        mask_image = torch.from_numpy(mask).unsqueeze(0).float()
        return io_image, mask_image


# Example usage
transform = None  # Define any transformations if needed
dataset = TileDataset(tiles, transform=transform)


In [None]:

# Batches of two tiles in dataset order. NOTE: leave num_workers at its default:
# TileDataset does not shard tiles across workers and its cache is per-process.
dataloader = DataLoader(dataset, shuffle=False, batch_size=2)

# Peek at the first batch only; `images` / `masks` stay bound for the viewer cell.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    images, masks = first_batch
    print(images.shape, masks.shape)
del first_batch


In [None]:
# Keys ("{image_id}_{tile_id}") of the tiles TileDataset has loaded and cached so far.
dataset.cache.keys()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# One figure per sample of the peeked batch: image on the left, mask on the right.
for i in range(len(images)):
    # CHW -> HWC for imshow; squeeze() drops any singleton channel axis of the mask.
    io_image = images[i].permute(1, 2, 0).numpy()
    mask_image = masks[i].squeeze().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (io_image, 'IO Image', None),
        (mask_image, 'Mask Image', 'gray'),
    )
    for ax, (pixels, title, cmap) in zip(axes, panels):
        ax.imshow(pixels, cmap=cmap)
        ax.set_title(title)
        ax.axis('off')

    plt.show()



In [None]:
#!pip install segmentation-models-pytorch

In [None]:
# NOTE(review): superseded by the positive/unlabeled-weighted training cell below;
# kept for reference only. If re-enabled, `len(dataloader)` raises TypeError because
# TileDataset (an IterableDataset) defines no __len__. Consider deleting this cell.
# import segmentation_models_pytorch as smp
# import torch
# from tqdm import tqdm

# import torch.nn as nn
# import torch.optim as optim

# # Define the model
# model = smp.Unet(encoder_name="timm-mobilenetv3_small_100", encoder_weights="imagenet", in_channels=3, classes=1)

# # Define the loss function and optimizer
# criterion = nn.BCEWithLogitsLoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# # Move the model to GPU if available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(device)

# model = model.to(device)

# # Training loop
# num_epochs = 10
# for epoch in range(num_epochs):
#     model.train()
#     running_loss = 0.0
#     for images, masks in tqdm(dataloader):
#         print("done loading batch")
#         images = images.to(device)
#         masks = masks.to(device)

#         # Zero the parameter gradients
#         optimizer.zero_grad()

#         # Forward pass
#         print("doing forward pass")
#         outputs = model(images)
#         loss = criterion(outputs, masks)
#         print("loss",loss)
#         # Backward pass and optimize
#         loss.backward()
#         optimizer.step()

#         running_loss += loss.item()

#     print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

# print("Training complete")

In [None]:
# Self-contained: `device` is otherwise first defined in the training cell below,
# so this cell raised NameError on a fresh kernel (Restart & Run All).
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

In [None]:
# Roll back the session's current transaction — presumably to recover after a
# failed query (e.g. inside TileDataset) before training reads tiles again.
session.rollback()

In [None]:
import segmentation_models_pytorch as smp
import torch
from tqdm import tqdm

import torch.nn as nn
import torch.optim as optim

SEED = 0
POSITIVE_WEIGHT = 1.0   # loss weight on annotated pixels (mask == 1)
UNLABELED_WEIGHT = 0.1  # loss weight on unannotated pixels (mask == 0)
num_epochs = 10

torch.manual_seed(SEED)  # reproducible initialisation of the non-pretrained layers

# Define the model (one output channel of raw logits).
# NOTE(review): smp.Unet expects H and W divisible by 32 — confirm the tile size.
model = smp.Unet(encoder_name="timm-mobilenetv3_small_100", encoder_weights="imagenet", in_channels=3, classes=1)

# reduction='none' keeps a per-pixel loss so positives and unlabeled pixels can be weighted separately.
criterion = nn.BCEWithLogitsLoss(reduction='none')
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Move the model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

model = model.to(device)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # TileDataset has no __len__, so len(dataloader) would raise TypeError; count batches instead.
    n_batches = 0
    progress = tqdm(dataloader, desc=f"epoch {epoch + 1}/{num_epochs}")
    for images, masks in progress:
        images = images.to(device)
        # BCEWithLogitsLoss needs a float target with the logits' (B, 1, H, W) shape;
        # also accept (B, H, W) integer masks.
        masks = masks.to(device).float()
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        # Mask for positives and unlabeled
        positive_mask = (masks == 1).float()
        unlabeled_mask = (masks == 0).float()

        # Both terms are averaged over *all* pixels, so the constants act as per-pixel
        # weights rather than per-group means.
        positive_loss = POSITIVE_WEIGHT * (loss * positive_mask).mean()
        unlabeled_loss = UNLABELED_WEIGHT * (loss * unlabeled_mask).mean()
        loss_total = positive_loss + unlabeled_loss

        loss_total.backward()
        optimizer.step()

        running_loss += loss_total.item()
        n_batches += 1
        progress.set_postfix(
            loss=loss_total.item(),
            positives=int(positive_mask.sum().item()),
            pos_loss=positive_loss.item(),
            unl_loss=unlabeled_loss.item(),
        )

    mean_loss = running_loss / n_batches if n_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss}")

print("Training complete")

In [19]:
# Roll back the session's current transaction — presumably to recover after a
# failed query during training.
session.rollback()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# Show the last training batch. The training cell rebinds `images` / `masks` to
# tensors on `device` (possibly CUDA), and .numpy() only works on CPU tensors,
# hence .detach().cpu().
for i in range(images.shape[0]):
    io_image = images[i].detach().cpu().permute(1, 2, 0).numpy()  # CHW -> HWC
    mask_image = masks[i].detach().cpu().squeeze().numpy()  # drop the channel axis

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))

    # Display the images
    axes[0].imshow(io_image)
    axes[0].set_title('IO Image')
    axes[0].axis('off')

    axes[1].imshow(mask_image, cmap='gray')
    axes[1].set_title('Mask Image')
    axes[1].axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop



In [None]:
# Distinct pixel values of the last displayed mask (rasterized masks hold 0 and 1).
np.unique(mask_image.ravel())