In [10]:
import sys; sys.path.insert(0, '..')

import math
import os
import pandas as pd
import numpy as np 

import torch

from torch.utils.data import Dataset, DataLoader
import SimpleITK as sitk
import nrrd
import pytorch_lightning as pl
from torchvision import transforms

import plotly.express as px
import plotly.graph_objects as go

In [7]:
# Root of the EGower data share. Set EGOWER_MOUNT_POINT to run this notebook
# on another machine instead of editing the hard-coded path.
mount_point = os.environ.get("EGOWER_MOUNT_POINT", "/work/jprieto/data/remote/EGower/jprieto")
analysis_dir = os.path.join(mount_point, 'Analysis_Set_202208')

# Train / eval / test split tables for the segmentation task.
train_fn = os.path.join(analysis_dir, 'trachoma_bsl_mtss_besrat_field_seg_train_202208_train.csv')
valid_fn = os.path.join(analysis_dir, 'trachoma_bsl_mtss_besrat_field_seg_train_202208_eval.csv')
test_fn = os.path.join(analysis_dir, 'trachoma_bsl_mtss_besrat_field_seg_test_202208.csv')

df_train = pd.read_csv(train_fn)
df_val = pd.read_csv(valid_fn)
df_test = pd.read_csv(test_fn)

# The dataset class below reads these columns by default; fail here with a
# clear message rather than deep inside __getitem__.
assert {"img_path", "seg_path"} <= set(df_train.columns), f"train CSV columns: {list(df_train.columns)}"



In [8]:
class TTDatasetSeg(Dataset):
    """Map-style dataset yielding ``(img, seg)`` float32 tensor pairs.

    Each row of ``df`` provides two file paths, relative to ``mount_point``,
    in the columns ``img_column`` and ``seg_column``. Both files are read with
    SimpleITK, size-1 axes are squeezed away, and the arrays are converted to
    float32 tensors.

    Args:
        df: table with one sample per row.
        mount_point: directory prepended to every path read from ``df``.
        transform: optional callable applied to ``{"img": ..., "seg": ...}``;
            it must return a dict with the same two keys.
        img_column: name of the column holding the image path.
        seg_column: name of the column holding the segmentation path.
    """
    def __init__(self, df, mount_point="./", transform=None, img_column="img_path", seg_column="seg_path"):
        self.df = df
        self.mount_point = mount_point
        self.transform = transform
        self.img_column = img_column
        self.seg_column = seg_column

    def __len__(self):
        return len(self.df)

    def _read_tensor(self, rel_path):
        """Read ``mount_point/rel_path`` and return it as a squeezed float32 tensor."""
        path = os.path.join(self.mount_point, rel_path)
        arr = np.squeeze(sitk.GetArrayFromImage(sitk.ReadImage(path)))
        return torch.tensor(arr).to(torch.float32)

    def __getitem__(self, idx):
        """Return ``(img, seg)`` for the row at position ``idx`` (0 .. len-1)."""
        # Positional lookup: ``.loc`` would treat idx as an index *label* and
        # fail (KeyError) or pick the wrong row whenever df was filtered or
        # shuffled without reset_index().
        row = self.df.iloc[idx]
        sample = {
            "img": self._read_tensor(row[self.img_column]),
            "seg": self._read_tensor(row[self.seg_column]),
        }
        if self.transform:
            sample = self.transform(sample)
        return sample["img"], sample["seg"]

In [9]:
class TrainTransformsSeg:
    """Random training-time augmentation for ``{"img", "seg"}`` dict samples.

    Pipeline: move the image channel axis first, add a channel axis to the
    segmentation, resize both to ``spatial_size``, randomly rotate and zoom
    both (each with probability 0.5), rescale image intensity, and finally
    apply colour jitter to the image only.

    Args:
        spatial_size: output (height, width) of the resize step; the default
            keeps the original 512x512.

    NOTE(review): Compose, AsChannelFirstd, AddChanneld, Resized, RandRotated,
    RandZoomd, ScaleIntensityd and Lambdad are not imported in the imports
    cell. They appear to be MONAI dictionary transforms (``monai.transforms``)
    and must be imported before this cell runs on a fresh kernel. Some of them
    (AsChannelFirstd, AddChanneld) may be deprecated in newer MONAI releases;
    verify against the installed version.
    """
    def __init__(self, spatial_size=(512, 512)):
        color_jitter = transforms.ColorJitter(brightness=[.5, 1.8], contrast=[0.5, 1.8], saturation=[.5, 1.8], hue=[-.2, .2])
        self.train_transform = Compose(
            [
                AsChannelFirstd(keys=["img"]),
                AddChanneld(keys=["seg"]),
                # Interpolation modes are paired with keys: smooth for "img",
                # nearest-neighbour for "seg" so label values are never blended.
                Resized(keys=["img", "seg"], spatial_size=list(spatial_size), mode=['area', 'nearest']),
                RandRotated(keys=["img", "seg"], prob=0.5, range_x=math.pi/2.0, range_y=math.pi/2.0, mode=["bilinear", "nearest"]),
                RandZoomd(keys=["img", "seg"], prob=0.5, min_zoom=0.8, max_zoom=1.2, mode=["area", "nearest"]),
                ScaleIntensityd(keys=["img"]),
                # Pass the ColorJitter module itself instead of
                # `lambda x: color_jitter(x)`: lambdas cannot be pickled, which
                # breaks DataLoader workers started with the "spawn" method.
                Lambdad(keys=['img'], func=color_jitter),
            ]
        )

    def __call__(self, inp):
        """Apply the augmentation pipeline to one sample dict and return the result."""
        return self.train_transform(inp)

In [13]:
# Overlay the segmentation of one augmented training sample on its image.
OPACITY = 0.4   # alpha of the segmentation heatmap overlay
FIG_SIZE = 512  # figure width and height in pixels

train_ds = TTDatasetSeg(df_train, mount_point, transform=TrainTransformsSeg())

img, seg = train_ds[0]

img_np = img.numpy()
seg_np = seg.numpy()
print(img_np.shape, seg_np.shape)

# The transforms return channel-first arrays (img via AsChannelFirstd, seg via
# AddChanneld); plotly needs (H, W[, 3]) for the image and a 2-D z for Heatmap.
img_disp = np.squeeze(np.moveaxis(img_np, 0, -1)) if img_np.ndim == 3 else img_np
seg_2d = np.squeeze(seg_np)

# assumes image intensities are in [0, 1] after ScaleIntensityd + ColorJitter — TODO confirm
img_u8 = (np.clip(img_disp, 0.0, 1.0) * 255).astype(np.uint8)

fig_img = px.imshow(img_u8)
fig_img.add_trace(go.Heatmap(z=seg_2d, opacity=OPACITY, colorscale='rdbu'))

fig_img.update_layout(
    autosize=False,
    width=FIG_SIZE,
    height=FIG_SIZE
)
fig_img

AttributeError: 'Tensor' object has no attribute 'to_numpy'