In [1]:
!git clone https://github.com/baovin/ADNet-VIN.git

fatal: destination path 'ADNet-VIN' already exists and is not an empty directory.


In [2]:
!pip install SimpleITK monai



In [24]:
import argparse
import random
import torch
import torch.backends.cudnn as cudnn

def parse_arguments():
    """Build the experiment configuration from the command line.

    Boolean options accept ``true/false``, ``yes/no``, ``t/f``, ``y/n`` and
    ``1/0`` (case-insensitive). Plain ``type=bool`` must not be used with
    argparse because ``bool('False')`` is ``True``.

    Returns:
        tuple: ``(namespace, unknown)`` as produced by
        ``ArgumentParser.parse_known_args``; ``unknown`` holds arguments that
        are not declared here (e.g. the ``-f <file>`` a Jupyter kernel adds).
    """

    def str2bool(value):
        # Convert a command-line token into a real boolean.
        if isinstance(value, bool):
            return value
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected, got %r' % (value,))

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', default="/content/ADNet-VIN/dataloading/", type=str)
    parser.add_argument('--save_root', default="/content/ADNet-VIN/log_out", type=str)
    parser.add_argument('--dataset', default='CHAOST2', type=str)
    parser.add_argument('--n_sv', default=1, type=int)
    parser.add_argument('--fold', default=1, type=int)

    # Training specs.
    parser.add_argument('--max_slices', default=10, type=int)
    parser.add_argument('--workers', default=4, type=int)
    parser.add_argument('--steps', default=15000, type=int)
    parser.add_argument('--n_shot', default=1, type=int)
    parser.add_argument('--n_query', default=1, type=int)
    parser.add_argument('--n_way', default=1, type=int)
    parser.add_argument('--batch-size', default=1, type=int)
    parser.add_argument('--max_iterations', default=50, type=int)
    parser.add_argument('--lr', default=1e-3, type=float)
    parser.add_argument('--lr_gamma', default=0.95, type=float)
    parser.add_argument('--momentum', default=0.9, type=float)
    parser.add_argument('--weight-decay', default=0.0005, type=float)
    parser.add_argument('--seed', default=None, type=int)
    parser.add_argument('--bg_wt', default=0.1, type=float)
    parser.add_argument('--t_loss_scaler', default=1.0, type=float)
    parser.add_argument('--min_size', default=200, type=int)

    # Inference specs.
    parser.add_argument('--all_slices', default=True, type=str2bool)
    parser.add_argument('--EP1', default=True, type=str2bool)

    # parse_known_args tolerates arguments injected by the notebook kernel.
    return parser.parse_known_args()

# Parse once for the whole notebook; 'unknown' holds the unrecognized args
# (like the '-f' connection-file argument passed by the Jupyter kernel).
args, unknown = parse_arguments()

# Deterministic setting for reproducibility.
if args.seed is not None:
    # Local import: this cell's import block does not import numpy, but the
    # training dataset draws from np.random, so that generator must be seeded too.
    import numpy as np

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    cudnn.deterministic = True

print(args)


Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, workers=4, steps=15000, n_shot=1, n_query=1, n_way=1, batch_size=1, max_iterations=50, lr=0.001, lr_gamma=0.95, momentum=0.9, weight_decay=0.0005, seed=None, bg_wt=0.1, t_loss_scaler=1.0, min_size=200, all_slices=True, EP1=True)


In [4]:
import sys
# Put the repository's dataloading directory on the module search path so the
# next cell can import `dataset_specifics` by its bare module name.
sys.path.append('/content/ADNet-VIN/dataloading/')

In [5]:
from dataset_specifics import get_folds, sample_xy, get_label_names

In [6]:
cd /content/

/content


In [7]:
!unzip superdix.zip

Archive:  superdix.zip
   creating: superdix/
  inflating: superdix/superpix-MIDDLE_1.nii.gz  
  inflating: superdix/superpix-MIDDLE_10.nii.gz  
  inflating: superdix/superpix-MIDDLE_13.nii.gz  
  inflating: superdix/superpix-MIDDLE_15.nii.gz  
  inflating: superdix/superpix-MIDDLE_19.nii.gz  
  inflating: superdix/superpix-MIDDLE_2.nii.gz  
  inflating: superdix/superpix-MIDDLE_20.nii.gz  
  inflating: superdix/superpix-MIDDLE_21.nii.gz  
  inflating: superdix/superpix-MIDDLE_22.nii.gz  
  inflating: superdix/superpix-MIDDLE_3.nii.gz  
  inflating: superdix/superpix-MIDDLE_31.nii.gz  
  inflating: superdix/superpix-MIDDLE_32.nii.gz  
  inflating: superdix/superpix-MIDDLE_33.nii.gz  
  inflating: superdix/superpix-MIDDLE_34.nii.gz  
  inflating: superdix/superpix-MIDDLE_36.nii.gz  
  inflating: superdix/superpix-MIDDLE_37.nii.gz  
  inflating: superdix/superpix-MIDDLE_38.nii.gz  
  inflating: superdix/superpix-MIDDLE_39.nii.gz  
  inflating: superdix/superpix-MIDDLE_5.nii.gz  
  inflat

In [8]:
import os
# Sanity check: show the directory that relative paths in later cells resolve against.
print(os.getcwd())

/content


In [9]:
cd /content/ADNet-VIN/dataloading/

/content/ADNet-VIN/dataloading


In [10]:
import torch
from torch.utils.data import Dataset
import glob
import os
import SimpleITK as sitk
import random
import numpy as np
# from .dataset_specifics import *
from monai.transforms.spatial.dictionary import Rand3DElasticd
from collections import defaultdict


class TestDataset(Dataset):
    """Evaluation dataset over the volumes of one cross-validation fold.

    The volumes whose (sorted) index is listed in ``get_folds(dataset)[fold]``
    are kept. The last of them is held out as the support volume (see
    ``getSupport``); the remaining ones are the query volumes returned by
    ``__getitem__``.

    ``self.label`` starts as ``None`` and is expected to be assigned by the
    caller before indexing: ``__getitem__`` binarizes the label volume against it.
    """

    def __init__(self, args):
        """
        Args:
            args: namespace providing ``dataset`` ('CMR' or 'CHAOST2'),
                ``data_root``, ``fold`` and ``EP1``.

        Raises:
            ValueError: if ``args.dataset`` is not a supported dataset name.
        """

        # reading the paths
        if args.dataset == 'CMR':
            pattern = 'cmr_MR_normalized/image*'
        elif args.dataset == 'CHAOST2':
            pattern = 'chaos_MR_T2_normalized/image*'
        else:
            # Previously this fell through and failed later with an AttributeError.
            raise ValueError('Unsupported dataset: ' + str(args.dataset))
        self.image_dirs = glob.glob(os.path.join(args.data_root, pattern))
        # sort numerically by the integer suffix of 'image_<n>.nii.gz'
        self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))

        # keep only the volumes of the test fold
        self.FOLD = get_folds(args.dataset)
        self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx in self.FOLD[args.fold]]

        # split into support/query
        self.support_dir = self.image_dirs[-1]
        self.image_dirs = self.image_dirs[:-1]  # remove support
        self.label = None

        # evaluation protocol
        self.EP1 = args.EP1

    def __len__(self):
        """Number of query volumes (the support volume is excluded)."""
        return len(self.image_dirs)

    @staticmethod
    def _load_image_and_mask(img_path, label):
        """Read a volume and its label file and return ``(img, lbl)``.

        ``img`` is z-score normalized with a leading channel axis of size 1;
        ``lbl`` is the integer 0/1 mask of ``label`` after remapping the raw
        label values 200/500/600 to 1/2/3.
        """
        img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        img = (img - img.mean()) / img.std()
        img = img[np.newaxis]

        # the label file sits next to the image: 'image_<n>' -> 'label_<n>'
        lbl = sitk.GetArrayFromImage(
            sitk.ReadImage(img_path.split('image_')[0] + 'label_' + img_path.split('image_')[-1]))
        lbl[lbl == 200] = 1
        lbl[lbl == 500] = 2
        lbl[lbl == 600] = 3
        lbl = 1 * (lbl == label)
        return img, lbl

    @staticmethod
    def _select_labeled_slices(img, lbl):
        """Keep only the slices (first label axis) that contain foreground.

        ``img`` carries a leading channel axis, so the slice mask is applied to
        its second axis.
        """
        keep = lbl.sum(axis=(1, 2)) > 0
        return img[:, keep], lbl[keep]

    def __getitem__(self, idx):
        """Return ``{'id', 'image', 'label'}`` for query volume ``idx``.

        With ``EP1`` only the slices that contain the target label are returned;
        otherwise the full volume is returned.
        """

        img_path = self.image_dirs[idx]
        img, lbl = self._load_image_and_mask(img_path, self.label)

        sample = {'id': img_path}

        # Evaluation protocol 1.
        if self.EP1:
            # Fix: the slice mask used to be applied to the channel axis
            # (img[mask]), which raises IndexError for a (1, D, H, W) array.
            img, lbl = self._select_labeled_slices(img, lbl)

        # Evaluation protocol 2 (default) keeps img/lbl unchanged.
        sample['image'] = torch.from_numpy(img)
        sample['label'] = torch.from_numpy(lbl)

        return sample

    def get_support_index(self, n_shot, C):
        """
        Selecting intervals according to Ouyang et al.

        Returns an integer array with ``n_shot`` positions in ``[0, C)``: the
        middle position for one shot, otherwise evenly spaced fractions of ``C``.
        """
        if n_shot == 1:
            pcts = [0.5]
        else:
            half_part = 1 / (n_shot * 2)
            part_interval = (1.0 - 1.0 / n_shot) / (n_shot - 1)
            pcts = [half_part + part_interval * ii for ii in range(n_shot)]

        return (np.array(pcts) * C).astype('int')

    def getSupport(self, label=None, all_slices=True, N=None):
        """Return the support sample ``{'image', 'label'}`` for class ``label``.

        Args:
            label: class id to binarize the support label against (required).
            all_slices: if True, return the whole support volume.
            N: number of labeled slices to select when ``all_slices`` is False.

        Raises:
            ValueError: if ``label`` is None, or if ``N`` is None while
                ``all_slices`` is False.
        """
        if label is None:
            raise ValueError('Need to specify label class!')

        img, lbl = self._load_image_and_mask(self.support_dir, label)

        sample = {}
        if all_slices:

            sample['image'] = torch.from_numpy(img)[None]
            sample['label'] = torch.from_numpy(lbl)[None]

        else:
            # select N labeled slices
            if N is None:
                raise ValueError('Need to specify number of labeled slices!')
            idx = lbl.sum(axis=(1, 2)) > 0
            idx_ = self.get_support_index(N, idx.sum())

            sample['image'] = torch.from_numpy(img[:, idx][:, idx_])[None]
            sample['label'] = torch.from_numpy(lbl[idx][idx_])[None]

        return sample


class TrainDataset(Dataset):
    """Self-supervised training dataset built from supervoxel pseudo-labels.

    All training volumes (those NOT in the test fold) and their supervoxel
    volumes are read into memory in ``__init__``. Every ``__getitem__`` call
    draws a random patient and a random supervoxel, then builds a support/query
    pair from windows of consecutive slices of the same volume, with one of the
    two randomly intensity- and geometry-augmented. The ``idx`` passed to
    ``__getitem__`` is ignored; ``__len__`` is ``args.max_iterations``.
    """

    def __init__(self, args):
        """
        Args:
            args: namespace providing ``n_shot``, ``n_way``, ``n_query``,
                ``n_sv``, ``max_iterations``, ``min_size``, ``max_slices``,
                ``dataset`` ('CMR' or 'CHAOST2'), ``data_root`` and ``fold``.

        Raises:
            ValueError: if ``args.dataset`` is unsupported, or if no volume has
                a supervoxel window with more than ``min_size`` voxels.
        """
        self.n_shot = args.n_shot
        self.n_way = args.n_way
        self.n_query = args.n_query
        self.n_sv = args.n_sv
        self.max_iter = args.max_iterations
        self.min_size = args.min_size
        self.max_slices = args.max_slices

        # reading the paths
        if args.dataset == 'CMR':
            image_pattern = 'cmr_MR_normalized/image*'
        elif args.dataset == 'CHAOST2':
            image_pattern = 'chaos_MR_T2_normalized/image*'
        else:
            # Previously this fell through and failed later with an AttributeError.
            raise ValueError('Unsupported dataset: ' + str(args.dataset))
        self.image_dirs = glob.glob(os.path.join(args.data_root, image_pattern))
        self.image_dirs = sorted(self.image_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))
        self.sprvxl_dirs = glob.glob(os.path.join(args.data_root, 'superdix/', 'super*'))
        # self.sprvxl_dirs = glob.glob(os.path.join(args.data_root, 'supervoxels_' + str(args.n_sv), 'super*'))
        self.sprvxl_dirs = sorted(self.sprvxl_dirs, key=lambda x: int(x.split('_')[-1].split('.nii.gz')[0]))

        # remove test fold!
        self.FOLD = get_folds(args.dataset)
        self.image_dirs = [elem for idx, elem in enumerate(self.image_dirs) if idx not in self.FOLD[args.fold]]
        self.sprvxl_dirs = [elem for idx, elem in enumerate(self.sprvxl_dirs) if idx not in self.FOLD[args.fold]]
        self.N = len(self.image_dirs)

        # read images and supervoxels into memory and pre-compute, per volume,
        # the slice windows in which each supervoxel is large enough to sample
        self.images = {}
        self.sprvxls = {}
        self.valid_spr_slices = {}
        for image_dir, sprvxl_dir in zip(self.image_dirs, self.sprvxl_dirs):
            img = sitk.ReadImage(image_dir)
            self.res = img.GetSpacing()  # spacing of the most recently read volume
            img = sitk.GetArrayFromImage(img)
            self.images[image_dir] = torch.from_numpy(img)
            spr = torch.from_numpy(sitk.GetArrayFromImage(sitk.ReadImage(sprvxl_dir)))
            self.sprvxls[sprvxl_dir] = spr

            # every supervoxel id except 0; filtering (instead of list.remove)
            # does not raise when a volume happens to contain no 0 label
            unique = [val for val in torch.unique(spr) if val != 0]
            self.valid_spr_slices[image_dir] = []
            for val in unique:
                spr_val = (spr == val)

                n_slices = min(spr_val.shape[0], self.max_slices)
                sample_list = []
                # slide a window of n_slices consecutive slices over the volume
                for r in range(spr_val.shape[0] - (n_slices - 1)):
                    sample_idx = list(range(r, r + n_slices))
                    candidate = spr_val[sample_idx]
                    if candidate.sum() > self.min_size:
                        sample_list.append(sample_idx)
                if len(sample_list) > 0:
                    self.valid_spr_slices[image_dir].append((val, sample_list))

        # Only patients with at least one usable supervoxel can be sampled;
        # drawing one without any used to crash in random.randint(0, -1).
        self.valid_patients = [idx for idx, image_dir in enumerate(self.image_dirs)
                               if len(self.valid_spr_slices.get(image_dir, [])) > 0]
        if len(self.valid_patients) == 0:
            raise ValueError('No training volume has a supervoxel window larger than min_size!')

        # set transformation details
        rad = 5 * (np.pi / 180)
        self.rand_3d_elastic = Rand3DElasticd(
            keys=("img", "seg"),
            mode=("bilinear", "nearest"),
            sigma_range=(5, 5),
            magnitude_range=(0, 0),
            prob=1.0,  # because probability controlled by this class
            rotate_range=(rad, rad, rad),
            shear_range=(rad, rad, rad),
            translate_range=(5, 5, 1),
            scale_range=((-0.1, 0.2), (-0.1, 0.2), (-0.1, 0.2)),
            # as_tensor_output=True,
            device='cpu')

    def __len__(self):
        """Number of samples per (sub-)epoch, i.e. ``args.max_iterations``."""
        return self.max_iter

    def gamma_tansform(self, img):
        """Apply a random gamma correction (gamma drawn uniformly from [0.5, 1.5)).

        The image is shifted to be positive, raised to the power gamma relative
        to its intensity range, and shifted back.
        """
        gamma_range = (0.5, 1.5)
        gamma = torch.rand(1) * (gamma_range[1] - gamma_range[0]) + gamma_range[0]
        cmin = img.min()
        irange = (img.max() - cmin + 1e-5)

        img = img - cmin + 1e-5
        img = irange * torch.pow(img * 1.0 / irange, gamma)
        img = img + cmin

        return img

    # Correctly spelled alias; the original (misspelled) name is kept for callers.
    gamma_transform = gamma_tansform

    def __getitem__(self, idx):
        """Return one randomly generated support/query episode.

        Returns:
            dict with ``'support_images'``, ``'support_fg_labels'``,
            ``'query_images'`` and ``'query_labels'``.
        """

        # sample patient idx (among patients that have a usable supervoxel)
        pat_idx = random.choice(self.valid_patients)

        # get image/supervoxel volume from dictionary
        img = self.images[self.image_dirs[pat_idx]]
        sprvxl = self.sprvxls[self.sprvxl_dirs[pat_idx]]

        # normalize
        img = (img - img.mean()) / img.std()

        # sample supervoxel
        valid = self.valid_spr_slices[self.image_dirs[pat_idx]]
        cls_idx, candidates = valid[random.randint(0, len(valid) - 1)]

        # binary pseudo-label of the chosen supervoxel
        sprvxl = 1 * (sprvxl == cls_idx)

        sup_lbl = torch.clone(sprvxl)
        qry_lbl = torch.clone(sprvxl)

        sup_img = torch.clone(img)
        qry_img = torch.clone(img)

        # gamma transform (applied to exactly one of query/support)
        if np.random.random(1) > 0.5:
            qry_img = self.gamma_tansform(qry_img)
        else:
            sup_img = self.gamma_tansform(sup_img)

        # geom transform (applied to exactly one of query/support)
        if np.random.random(1) > 0.5:
            # the transform is fed (H, W, D)-ordered tensors, hence the permutes
            res = self.rand_3d_elastic({"img": qry_img.permute(1, 2, 0),
                                        "seg": qry_lbl.permute(1, 2, 0)})

            qry_img = res["img"].permute(2, 0, 1)
            qry_lbl = res["seg"].permute(2, 0, 1)

            # support not tformed
            constant_s = random.randint(0, len(candidates) - 1)
            idx_s = candidates[constant_s]

            # query window: at most k candidate positions away from the support window
            k = 50
            constant_q = constant_s + random.randint(-min(constant_s, k), min(len(candidates) - constant_s - 1, k))
            idx_q = candidates[constant_q]

        else:
            res = self.rand_3d_elastic({"img": sup_img.permute(1, 2, 0),
                                        "seg": sup_lbl.permute(1, 2, 0)})

            sup_img_ = res["img"].permute(2, 0, 1)
            sup_lbl_ = res["seg"].permute(2, 0, 1)

            constant_q = random.randint(0, len(candidates) - 1)
            idx_q = candidates[constant_q]

            k = 50
            constant_s = constant_q + random.randint(-min(constant_q, k), min(len(candidates) - constant_q - 1, k))
            idx_s = candidates[constant_s]
            # keep the transformed support only if it still contains foreground
            if sup_lbl_[idx_s].sum() > 0:
                sup_img = sup_img_
                sup_lbl = sup_lbl_

        sup_lbl = sup_lbl[idx_s]
        qry_lbl = qry_lbl[idx_q]

        sup_img = sup_img[idx_s]
        qry_img = qry_img[idx_q]

        # crop a b x b in-plane window; the corner comes from sample_xy
        # (its exact sampling rule lives in dataset_specifics - not visible here)
        b = 215
        k = 0
        horizontal_s, vertical_s = sample_xy(sup_lbl, k=k, b=b)
        horizontal_q, vertical_q = sample_xy(qry_lbl, k=k, b=b)

        sup_img = sup_img[:, horizontal_s:horizontal_s + b, vertical_s:vertical_s + b]
        sup_lbl = sup_lbl[:, horizontal_s:horizontal_s + b, vertical_s:vertical_s + b]
        qry_img = qry_img[:, horizontal_q:horizontal_q + b, vertical_q:vertical_q + b]
        qry_lbl = qry_lbl[:, horizontal_q:horizontal_q + b, vertical_q:vertical_q + b]

        sample = {'support_images': torch.stack(1 * [sup_img], dim=0),
                  'support_fg_labels': sup_lbl[None],
                  'query_images': torch.stack(1 * [qry_img], dim=0),
                  'query_labels': qry_lbl}

        return sample


In [11]:
# Smoke test: build the training dataset/loader and inspect one sample.
train_dataset = TrainDataset(args)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.workers,
    pin_memory=True,
    drop_last=True,
)
print(len(train_dataset))
data = train_dataset[0]
for key in ('support_images', 'support_fg_labels', 'query_images', 'query_labels'):
    print(data[key].shape)
print(len(train_loader))



50
torch.Size([1, 10, 215, 215])
torch.Size([1, 10, 215, 215])
torch.Size([1, 10, 215, 215])
torch.Size([10, 215, 215])
50


In [12]:
# Put the repository root on the module search path so that later cells can
# import its top-level modules (e.g. `utils`, `dataloading`).
sys.path.append('/content/ADNet-VIN/')

In [13]:
!pip install wandb



In [14]:
!pip install gdown



In [15]:
cd /content/

/content


In [16]:
!gdown 1p80RJsghFIKBSLKgtRG94LE38OGY5h4y

Downloading...
From (original): https://drive.google.com/uc?id=1p80RJsghFIKBSLKgtRG94LE38OGY5h4y
From (redirected): https://drive.google.com/uc?id=1p80RJsghFIKBSLKgtRG94LE38OGY5h4y&confirm=t&uuid=169f12b6-fd48-4397-806d-99b76a02781c
To: /content/r3d101_KM_200ep.pth
100% 700M/700M [00:12<00:00, 54.7MB/s]


In [17]:
!mv r3d101_KM_200ep.pth resnext-101-kinetics.pth

In [27]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial

# Public names of this backbone; the list previously named 'resnet50' and
# 'resnet101', which are not defined here (the factories are resnext*).
__all__ = ['ResNeXt', 'resnext50', 'resnext101', 'resnext152']


def conv3x3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3x3 ``nn.Conv3d`` with padding 1 and the given stride."""
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv3d(in_planes, out_planes, **conv_options)


def downsample_basic_block(x, planes, stride):
    """Parameter-free shortcut: strided subsampling plus zero-padded channels.

    Args:
        x: input tensor of shape (N, C, D, H, W).
        planes: number of output channels (expected to be >= C).
        stride: stride of the 1x1x1 average pooling used for subsampling.

    Returns:
        Tensor of shape (N, planes, D', H', W'). The first C channels are the
        subsampled input; the remaining channels are zeros.

    The padding is allocated with ``new_zeros`` so it always matches the dtype
    and device of the input (the previous version only handled float32 on the
    default CUDA device), and the shortcut stays attached to the autograd graph
    instead of being detached through ``.data``.
    """
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    pad_shape = (out.size(0), planes - out.size(1), out.size(2), out.size(3), out.size(4))
    zero_pads = out.new_zeros(pad_shape)

    return torch.cat([out, zero_pads], dim=1)


class ResNeXtBottleneck(nn.Module):
    """3D ResNeXt bottleneck: 1x1x1 reduce -> grouped 3x3x3 -> 1x1x1 expand.

    The block outputs ``planes * expansion`` channels and adds a residual
    connection, which goes through ``downsample`` when one is given.
    """
    expansion = 2

    def __init__(self, inplanes, planes, cardinality, stride=1,
                 downsample=None, dilation=1):
        super().__init__()
        width = cardinality * int(planes / 32)  # channels of the grouped conv
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv3d(inplanes, width, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(width)
        # padding equals dilation so the 3x3x3 conv keeps the size at stride 1
        self.conv2 = nn.Conv3d(
            width, width,
            kernel_size=3,
            stride=stride,
            padding=dilation,
            groups=cardinality,
            dilation=dilation,
            bias=False)
        self.bn2 = nn.BatchNorm3d(width)
        self.conv3 = nn.Conv3d(width, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the bottleneck and return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        residual = x if self.downsample is None else self.downsample(x)

        out += residual
        return self.relu(out)


class ResNeXt(nn.Module):
    """3D ResNeXt feature extractor (stem + four residual stages, no classifier).

    Args:
        block: residual block class; it must expose ``expansion`` and accept
            ``(inplanes, planes, cardinality, stride, downsample, dilation=...)``.
        layers: number of blocks in each of the four stages.
        shortcut_type: 'A' for the parameter-free ``downsample_basic_block``
            shortcut, anything else for a 1x1x1 conv + batch-norm shortcut.
        cardinality: number of groups passed to every block.
        replace_stride_with_dilation: three booleans (stages 2-4); when True the
            stage's stride is replaced by an equal dilation.

    ``self.dilation`` is kept as a plain tuple of three ints. It used to be a
    tensor that was multiplied in place, which also changed the
    "previous dilation" handed to the first block of a dilated stage, because
    both names referred to the same tensor.
    """

    def __init__(self,
                 block,
                 layers,
                 shortcut_type='B',
                 cardinality=32,
                 replace_stride_with_dilation=None):
        super(ResNeXt, self).__init__()
        self.inplanes = 64
        self.dilation = (1, 1, 1)
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        self.dropout = nn.Dropout(p=0.2)
        self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), padding=(3, 3, 3), bias=False)  # NOTE: this is being over-written in the main script!
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))
        self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type,
                                       cardinality)
        self.layer2 = self._make_layer(
            block, 256, layers[1], shortcut_type, cardinality, stride=(1, 2, 2), dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(
            block, 512, layers[2], shortcut_type, cardinality, stride=(1, 1, 1), dilate=replace_stride_with_dilation[1])  # (1, 2, 2) or (2 ,2, 2)
        self.layer4 = self._make_layer(
            block, 1024, layers[3], shortcut_type, cardinality, stride=(1, 1, 1), dilate=replace_stride_with_dilation[2])  # (1, 2, 2) or (2 ,2, 2)

        # Kaiming init for all convolutions, identity init for all batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    shortcut_type,
                    cardinality,
                    stride=1,
                    dilate=False):
        """Build one residual stage of ``blocks`` blocks as an ``nn.Sequential``.

        The first block applies ``stride`` (or, when ``dilate`` is True, stride 1
        with the dilation accumulated so far) and the shortcut projection; the
        remaining blocks use stride 1 and the stage's (possibly increased) dilation.
        """
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            # Build a NEW tuple so that previous_dilation keeps its old value.
            stride_3d = (stride,) * 3 if isinstance(stride, int) else tuple(stride)
            self.dilation = tuple(d * s for d, s in zip(self.dilation, stride_3d))
            stride = 1

        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, cardinality, stride, downsample, dilation=previous_dilation))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cardinality, dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the stage-4 feature map for input ``x`` (N, 3, D, H, W)."""

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        return x


def get_fine_tuning_parameters(model, ft_portion):
    """Select the parameters to optimize for fine-tuning.

    Args:
        model: module whose parameters are inspected.
        ft_portion: "complete" returns ``model.parameters()``; "last_layer"
            returns one optimizer param-group per parameter, where every
            parameter whose name does not contain 'fc' gets ``lr`` 0.0.

    Raises:
        ValueError: for any other ``ft_portion``.
    """
    if ft_portion == "complete":
        return model.parameters()

    if ft_portion != "last_layer":
        raise ValueError("Unsupported ft_portion: 'complete' or 'last_layer' expected")

    ft_module_names = ['fc']
    param_groups = []
    for param_name, param in model.named_parameters():
        group = {'params': param}
        is_fine_tuned = any(tag in param_name for tag in ft_module_names)
        if not is_fine_tuned:
            # frozen through a zero learning rate
            group['lr'] = 0.0
        param_groups.append(group)
    return param_groups


def resnext50(**kwargs):
    """Build a ResNeXt with stage depths [3, 4, 6, 3] from ``ResNeXtBottleneck`` blocks.

    Keyword arguments are forwarded to ``ResNeXt``.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNeXt(ResNeXtBottleneck, stage_depths, **kwargs)


def resnext101(pretrained_path='/content/resnext-101-kinetics.pth', load_pretrained=True, **kwargs):
    """Build a ResNeXt with stage depths [3, 4, 23, 3], optionally pre-trained.

    The network is wrapped in ``nn.DataParallel`` before the checkpoint is read,
    so its state-dict keys carry the ``module.`` prefix.

    Args:
        pretrained_path: checkpoint file; it must contain a ``"state_dict"`` entry.
        load_pretrained: if False, return the randomly initialized network.
        **kwargs: forwarded to ``ResNeXt``.

    Returns:
        ``nn.DataParallel`` wrapping the ``ResNeXt`` backbone.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs)
    model = nn.DataParallel(model, device_ids=[0, ])
    if load_pretrained:
        print('Loading pre-trained weights!')
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(pretrained_path, map_location='cpu')

        model_dict = model.state_dict()

        # 1. keep only entries that exist in this model with a matching shape
        #    (a shape mismatch would make load_state_dict raise)
        pretrained_dict = {k: v for k, v in checkpoint["state_dict"].items()
                           if k in model_dict and v.shape == model_dict[k].shape}

        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)

    return model


def resnext152(**kwargs):
    """Build a ResNeXt with stage depths [3, 8, 36, 3] from ``ResNeXtBottleneck`` blocks.

    Keyword arguments are forwarded to ``ResNeXt``.
    """
    stage_depths = [3, 8, 36, 3]
    return ResNeXt(ResNeXtBottleneck, stage_depths, **kwargs)


# import SimpleITK as sitk
#
# encoder101 = resnext101(replace_stride_with_dilation=[False, True, True])
# #encoder50 = resnext50(sample_size=None, sample_duration=None, replace_stride_with_dilation=[False, True, True])
# img = sitk.GetArrayFromImage(
#     sitk.ReadImage('/Users/sha168/gitRepos/springfield_files2/data/CMR/cmr_MR_normalized/image_35.nii.gz'))
# img = torch.from_numpy(img)
# img = img[None, None].repeat([1, 3, 1, 1, 1])
# fts101 = encoder101(img.float())
# #fts50 = encoder50(img.float())
# print('ResNeXt-101', img.shape, fts101.shape)
# #print('ResNeXt-50', img.shape, fts50.shape)


In [26]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
# from .backbone.resnext3D import resnext101


class FewShotSeg(nn.Module):
    """Few-shot 3D segmentation via prototype-based anomaly scoring.

    A foreground prototype is pooled from the support features under the support
    mask. Every query location gets the scaled negative cosine similarity to
    that prototype as its anomaly score, and a learned threshold ``t`` turns the
    score into a soft foreground probability.
    """

    def __init__(self, args):
        super().__init__()

        # Encoder: ResNeXt-101 backbone followed by a 1x1x1 projection to 256 channels
        self.encoder = nn.Sequential(resnext101(replace_stride_with_dilation=[False, True, True]),
                                     nn.Conv3d(2048, 256, kernel_size=1, stride=1, bias=False))
        # Kept for callers that read it. The loss accumulators below are created
        # on the device of the features instead, so the model also runs on CPU
        # and on DataParallel replicas that do not live on the default GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.t = Parameter(torch.Tensor([-10.0]))  # learned anomaly threshold
        self.scaler = 20.0
        self.criterion = nn.NLLLoss()

    def forward(self, supp_imgs, fore_mask, qry_imgs, train=False, t_loss_scaler=1):
        """
        Args:
            supp_imgs: support images
                way x shot x [B x D x H x W], list of lists of tensors
            fore_mask: foreground masks for support images
                way x shot x [B x D x H x W], list of lists of tensors
            qry_imgs: query images
                N x [B x 1 x D x H x W], list of tensors
            train: if True, also compute the prototype alignment loss
            t_loss_scaler: factor applied to the threshold loss

        Returns:
            tuple ``(output, align_loss, t_loss)``: ``output`` holds the
            background/foreground probabilities at the query resolution,
            ``align_loss`` is averaged over the batch and ``t_loss`` is
            ``t_loss_scaler * t / scaler``.
        """

        n_ways = len(supp_imgs)
        self.n_shots = len(supp_imgs[0])
        n_queries = len(qry_imgs)
        batch_size = supp_imgs[0][0].shape[0]
        img_size = qry_imgs[0].shape[-3:]

        # ###### Extract features ######
        s_imgs_concat = torch.cat([torch.stack(way, dim=0) for way in supp_imgs], dim=0)
        q_imgs_concat = torch.cat(qry_imgs, dim=0)

        # the single channel is repeated 3 times for the 3-channel backbone stem
        s_img_fts = self.encoder(s_imgs_concat.repeat([1, 3, 1, 1, 1]))
        q_img_fts = self.encoder(q_imgs_concat.repeat([1, 3, 1, 1, 1]))

        s_fts_size = s_img_fts.shape[-3:]
        q_fts_size = q_img_fts.shape[-3:]

        supp_fts = s_img_fts.view(
            n_ways, self.n_shots, batch_size, -1, *s_fts_size)  # Wa x Sh x B x C x D' x H' x W'
        qry_fts = q_img_fts.view(
            n_queries, batch_size, -1, *q_fts_size)  # N x B x C x D' x H' x W'

        fore_mask = torch.stack([torch.stack(way, dim=0)
                                 for way in fore_mask], dim=0)  # Wa x Sh x B x D x H x W

        ###### Compute loss ######
        align_loss = torch.zeros(1, device=qry_fts.device)
        outputs = []
        for epi in range(batch_size):

            ###### Extract prototypes ######
            supp_fts_ = [[self.getFeatures(supp_fts[way, shot, [epi]],
                                           fore_mask[way, shot, [epi]])
                          for shot in range(self.n_shots)] for way in range(n_ways)]

            fg_prototypes = self.getPrototype(supp_fts_)

            ###### Compute anom. scores ######
            anom_s = [self.negSim(qry_fts[:, epi], prototype) for prototype in fg_prototypes]

            ###### Get threshold #######
            self.thresh_pred = [self.t for _ in range(n_ways)]
            self.t_loss = self.t / self.scaler

            ###### Get predictions #######
            pred = self.getPred(anom_s, self.thresh_pred)  # N x Wa x D' x H' x W'

            pred_ups = F.interpolate(pred, size=img_size, mode='trilinear', align_corners=True)
            pred_ups = torch.cat((1.0 - pred_ups, pred_ups), dim=1)

            outputs.append(pred_ups)

            ###### Prototype alignment loss ######
            if train:
                align_loss_epi = self.alignLoss(qry_fts[:, epi], torch.cat((1.0 - pred, pred), dim=1),
                                                supp_fts[:, :, epi],
                                                fore_mask[:, :, epi])
                align_loss += align_loss_epi

        output = torch.stack(outputs, dim=1)  # N x B x (1 + Wa) x D x H x W
        output = output.view(-1, *output.shape[2:])
        return output, (align_loss / batch_size), (t_loss_scaler * self.t_loss)

    def negSim(self, fts, prototype):
        """
        Calculate the anomaly score: the negative cosine similarity between
        features and a prototype, multiplied by ``self.scaler``

        Args:
            fts: input features
                expect shape: N x C x D x H x W
            prototype: prototype of one semantic class
                expect shape: 1 x C
        """

        sim = - F.cosine_similarity(fts, prototype[..., None, None, None], dim=1) * self.scaler

        return sim

    def getFeatures(self, fts, mask):
        """
        Extract foreground features via masked average pooling

        Args:
            fts: input features, expect shape: 1 x C x D' x H' x W'
            mask: binary mask, expect shape: 1 x D x H x W
        """

        # bring the mask to the feature resolution (rather than upsampling features)
        mask = F.interpolate(mask[None], size=fts.shape[-3:], mode='nearest')[0]

        # masked fg features
        masked_fts = torch.sum(fts * mask[None, ...], dim=(2, 3, 4)) \
                     / (mask[None, ...].sum(dim=(2, 3, 4)) + 1e-5)  # 1 x C

        return masked_fts

    def getPrototype(self, fg_fts):
        """
        Average the features over the shots to obtain one prototype per way

        Args:
            fg_fts: lists of list of foreground features for each way/shot
                expect shape: Wa x Sh x [1 x C]

        Returns:
            list (length Wa) of prototypes, each of shape 1 x C
        """

        n_ways, n_shots = len(fg_fts), len(fg_fts[0])
        fg_prototypes = [torch.sum(torch.cat([tr for tr in way], dim=0), dim=0, keepdim=True) / n_shots for way in
                         fg_fts]  ## concat all fg_fts

        return fg_prototypes

    def alignLoss(self, qry_fts, pred, supp_fts, fore_mask):
        """
        Prototype alignment loss: build prototypes from the query features under
        the predicted query mask, segment the support features with them and
        score that segmentation against the support mask

        Args:
            qry_fts: query features, expect shape: N x C x D' x H' x W'
            pred: query prediction, expect shape: N x (1 + Wa) x D' x H' x W'
            supp_fts: support features, expect shape: Wa x Sh x C x D' x H' x W'
            fore_mask: support foreground masks, expect shape: Wa x Sh x D x H x W

        Returns:
            loss tensor of shape (1,), on the device of ``qry_fts``
        """

        n_ways, n_shots = len(fore_mask), len(fore_mask[0])
        fore_mask = F.interpolate(fore_mask, size=qry_fts.shape[-3:], mode='nearest')

        # Mask and get query prototype
        pred_mask = pred.argmax(dim=1, keepdim=True)  # N x 1 x D' x H' x W'
        binary_masks = [pred_mask == i for i in range(1 + n_ways)]
        # ways without any predicted foreground cannot provide a prototype
        skip_ways = [i for i in range(n_ways) if binary_masks[i + 1].sum() == 0]
        pred_mask = torch.stack(binary_masks, dim=1).float()  # N x (1 + Wa) x 1 x D' x H' x W'

        qry_prototypes = torch.sum(qry_fts.unsqueeze(1) * pred_mask, dim=(0, 3, 4, 5))
        qry_prototypes = qry_prototypes / (pred_mask.sum((0, 3, 4, 5)) + 1e-5)  # (1 + Wa) x C

        # Compute the support loss
        loss = torch.zeros(1, device=qry_fts.device)
        for way in range(n_ways):
            if way in skip_ways:
                continue
            # Get the query prototypes
            for shot in range(n_shots):
                img_fts = supp_fts[way, [shot]]
                supp_sim = self.negSim(img_fts, qry_prototypes[[way + 1]])

                pred = self.getPred([supp_sim], [self.thresh_pred[way]])  # 1 x 1 x D' x H' x W'
                pred_ups = torch.cat((1.0 - pred, pred), dim=1)

                # Construct the support Ground-Truth segmentation
                supp_label = torch.full_like(fore_mask[way, shot], 255, device=img_fts.device).long()
                supp_label[fore_mask[way, shot] == 1] = 1
                supp_label[fore_mask[way, shot] == 0] = 0

                # Compute Loss (clamped so that log never sees exactly 0 or 1)
                eps = torch.finfo(torch.float32).eps
                log_prob = torch.log(torch.clamp(pred_ups, eps, 1 - eps))
                loss += self.criterion(log_prob, supp_label[None, ...].long()) / n_shots / n_ways

        return loss

    def getPred(self, sim, thresh):
        """
        Turn anomaly scores into soft foreground probabilities:
        ``1 - sigmoid(0.5 * (score - threshold))`` for every way

        Args:
            sim: list (length Wa) of anomaly scores
            thresh: list (length Wa) of thresholds
        """
        pred = []

        for s, t in zip(sim, thresh):
            pred.append(1.0 - torch.sigmoid(0.5 * (s - t)))

        return torch.stack(pred, dim=1)  # N x Wa x D' x H' x W'




In [20]:
#!/usr/bin/env python

import argparse
import time
import random
import numpy as np

import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.optim.lr_scheduler import MultiStepLR
import tqdm

# from models.fewshot_anom_3D import FewShotSeg
# from dataloading.datasets_3D import TrainDataset, TestDataset
from torch.utils.data import DataLoader
from utils import *
from dataloading.dataset_specifics import *

def main():
    """Train the few-shot segmentation model and save its weights.

    Parses the notebook arguments, builds the model, the SGD optimizer with a
    step-wise learning-rate schedule and the weighted NLL loss, then calls
    ``train`` for ``args.steps // args.max_iterations`` sub-epochs, logging the
    averaged timings and losses after each one. The model state dict is written
    to ``<save_root>/model.pth`` periodically and once more when training ends.
    """
    import os  # Not imported explicitly by this cell; needed for os.path.join below.

    # Parse here instead of depending on a global `args` left behind by another
    # cell. parse_arguments() uses parse_known_args(), so it returns a
    # (known_args, unknown_args) pair.
    args, _ = parse_arguments()

    # Deterministic setting for reproducibility (no-op when --seed is not given).
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    # Set up logging.
    logger = set_logger(args.save_root, 'train.log')
    logger.info(args)

    # Setup the path to save.
    args.save_model_path = os.path.join(args.save_root, 'model.pth')

    # Init model.
    model = FewShotSeg(args)
    model = nn.DataParallel(model.cuda())

    # Init optimizer. The scheduler is stepped once per training iteration
    # (inside train), so the milestones are expressed in iterations.
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    milestones = [(ii + 1) * 1000 for ii in range(args.steps // 1000 - 1)]
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=args.lr_gamma)

    # Define loss function: two classes (background, foreground); 255 marks ignored pixels.
    my_weight = torch.FloatTensor([args.bg_wt, 1.0]).cuda()
    criterion = nn.NLLLoss(ignore_index=255, weight=my_weight)

    # Let cuDNN auto-tune the convolution algorithm, except when a seed was
    # requested: benchmark mode may pick different algorithms between runs.
    cudnn.enabled = True
    cudnn.benchmark = args.seed is None

    # Define data set and loader.
    train_dataset = TrainDataset(args)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True, drop_last=True)
    logger.info('  Training on images not in test fold: ' +
                str([elem[len(args.data_root):] for elem in train_dataset.image_dirs]))

    # Start training.
    sub_epochs = args.steps // args.max_iterations
    checkpoint_every = 30  # Sub-epochs between intermediate checkpoints (first one after epoch index 29).
    logger.info('  Start training ...')

    for epoch in range(sub_epochs):
        # Train.
        batch_time, data_time, losses, q_loss, align_loss, t_loss = train(train_loader, model, criterion, optimizer,
                                                                          scheduler, epoch, args)

        # Log
        logger.info('============== Epoch [{}] =============='.format(epoch))
        logger.info('  Batch time: {:6.3f}'.format(batch_time))
        logger.info('  Loading time: {:6.3f}'.format(data_time))
        logger.info('  Total Loss  : {:.5f}'.format(losses))
        logger.info('  Query Loss  : {:.5f}'.format(q_loss))
        logger.info('  Align Loss  : {:.5f}'.format(align_loss))
        logger.info('  Threshold Loss  : {:.5f}'.format(t_loss))

        # Periodic checkpoint so an interrupted session does not lose all progress.
        if (epoch + 1) % checkpoint_every == 0:
            torch.save(model.state_dict(), args.save_model_path)

    # Save trained model.
    logger.info('  Saving model ...')
    torch.save(model.state_dict(), args.save_model_path)


def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    """Run one pass over ``train_loader`` and return averaged timings and losses.

    For every episode the support/query tensors are moved to the GPU, the model
    is called in training mode, and the total loss (query + align + threshold)
    is back-propagated; the optimizer and the scheduler are both stepped once
    per episode.

    Args:
        train_loader: iterable yielding dicts with the keys 'support_images',
            'support_fg_labels', 'query_images' and 'query_labels'.
        model: called as ``model(support_images, support_fg_mask, query_images,
            train=True, t_loss_scaler=...)``; returns the query prediction, the
            align loss and the threshold loss.
        criterion: loss applied to the log of the clamped query prediction.
        optimizer: optimizer stepped after each backward pass.
        scheduler: learning-rate scheduler stepped after each optimizer step.
        epoch: accepted for interface compatibility; not used in the body.
        args: namespace providing ``t_loss_scaler``.

    Returns:
        Tuple of averages: (batch time, data time, total loss, query loss,
        align loss, threshold loss).
    """
    meter_batch = AverageMeter('Time', ':6.3f')
    meter_data = AverageMeter('Data', ':6.3f')
    meter_total = AverageMeter('Loss', ':.4f')
    meter_query = AverageMeter('Query loss', ':.4f')
    meter_align = AverageMeter('Align loss', ':.4f')
    meter_thresh = AverageMeter('Threshold loss', ':.4f')

    def _ways_to_cuda(ways):
        # Add a leading dimension to every shot and move it to the GPU as float.
        return [[shot[None].float().cuda() for shot in way] for way in ways]

    # Bounds used to keep the log() below finite.
    eps = torch.finfo(torch.float32).eps

    # Train mode.
    model.train()

    tick = time.time()
    for episode in tqdm.tqdm(train_loader):

        # Extract episode data.
        support_images = _ways_to_cuda(episode['support_images'])
        support_fg_mask = _ways_to_cuda(episode['support_fg_labels'])

        query_images = [img.float().cuda() for img in episode['query_images']]
        label_parts = [lbl.long().cuda() for lbl in episode['query_labels']]
        query_labels = torch.cat(label_parts, dim=0)

        # Log loading time.
        meter_data.update(time.time() - tick)

        # Compute outputs and losses.
        query_pred, align_loss, thresh_loss = model(support_images, support_fg_mask, query_images,
                                                    train=True, t_loss_scaler=args.t_loss_scaler)

        log_prob = torch.log(torch.clamp(query_pred, eps, 1 - eps))
        query_loss = criterion(log_prob, query_labels[None])
        loss = query_loss + align_loss + thresh_loss

        # Reset gradients, then compute gradient and do SGD step.
        for weight in model.parameters():
            weight.grad = None

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss.
        n_pred = query_pred.size(0)
        for meter, value in ((meter_total, loss), (meter_query, query_loss),
                             (meter_align, align_loss), (meter_thresh, thresh_loss)):
            meter.update(value.item(), n_pred)

        # Log elapsed time.
        meter_batch.update(time.time() - tick)
        tick = time.time()

    return (meter_batch.avg, meter_data.avg, meter_total.avg,
            meter_query.avg, meter_align.avg, meter_thresh.avg)


# Start training when this cell (or the equivalent script) is executed directly.
if __name__ == '__main__':
    main()


FileExistsError: [Errno 17] File exists: '/content/ADNet-VIN/log_out'

In [21]:
import os
import shutil

# Define the folder path to remove.
# NOTE: the training cells write 'model.pth' under their save_root, whose default
# is this same folder, so removing it also deletes any saved checkpoint.
folder_path = '/content/ADNet-VIN/log_out'

# Remove the folder with the standard library instead of shelling out to `rm -r`,
# and only report a removal that actually happened.
if os.path.isdir(folder_path):
    shutil.rmtree(folder_path)
    print(f"Folder '{folder_path}' has been removed.")
else:
    print(f"Folder '{folder_path}' does not exist; nothing to remove.")

Folder '/content/ADNet-VIN/log_out' has been removed.


In [None]:
def get_label_names(dataset):
    """Return the ``{label value: label name}`` mapping for a dataset.

    Args:
        dataset: dataset identifier; 'CMR' and 'CHAOST2' are recognised.

    Returns:
        A new dict keyed by consecutive integers starting at 0 (0 is 'BG').
        Any other identifier yields an empty dict.
    """
    if dataset == 'CMR':
        ordered_names = ('BG', 'LV-MYO', 'LV-BP', 'RV')
    elif dataset == 'CHAOST2':
        ordered_names = ('BG', 'LIVER', 'RK', 'LK', 'SPLEEN')
    else:
        ordered_names = ()

    return dict(enumerate(ordered_names))


In [22]:
#!/usr/bin/env python

import argparse
import time
import random
import numpy as np

import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.optim.lr_scheduler import MultiStepLR
import tqdm

# from models.fewshot_anom_3D import FewShotSeg
# from dataloading.datasets_3D import TrainDataset, TestDataset
from torch.utils.data import DataLoader
from utils import *
from dataloading.dataset_specifics import *

def main(num_epochs=10):
    """Train the few-shot segmentation model for a fixed number of sub-epochs.

    Parses the notebook arguments, builds the model, the SGD optimizer with a
    step-wise learning-rate schedule and the weighted NLL loss, then calls
    ``train`` once per sub-epoch, logging the averaged timings and losses and
    saving the model state dict to ``<save_root>/model.pth`` after every
    sub-epoch.

    Args:
        num_epochs: number of passes over the training loader. Defaults to 10,
            the value that was previously hard-coded in the loop; pass
            ``args.steps // args.max_iterations`` worth of epochs for a run
            sized by the command-line arguments instead.
    """
    import os  # Not imported explicitly by this cell; needed for os.path.join below.

    # Parse here instead of depending on a global `args` left behind by another
    # cell. parse_arguments() uses parse_known_args(), so it returns a
    # (known_args, unknown_args) pair.
    args, _ = parse_arguments()

    # Deterministic setting for reproducibility (no-op when --seed is not given).
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    # Set up logging.
    logger = set_logger(args.save_root, 'train.log')
    logger.info(args)

    # Setup the path to save.
    args.save_model_path = os.path.join(args.save_root, 'model.pth')

    # Init model.
    model = FewShotSeg(args)
    model = torch.nn.DataParallel(model.cuda())

    # Init optimizer. The scheduler is stepped once per training iteration
    # (inside train), so the milestones are expressed in iterations.
    optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    milestones = [(ii + 1) * 1000 for ii in range(args.steps // 1000 - 1)]
    scheduler = MultiStepLR(optimizer, milestones=milestones, gamma=args.lr_gamma)

    # Define loss function: two classes (background, foreground); 255 marks ignored pixels.
    my_weight = torch.FloatTensor([args.bg_wt, 1.0]).cuda()
    criterion = nn.NLLLoss(ignore_index=255, weight=my_weight)

    # Let cuDNN auto-tune the convolution algorithm, except when a seed was
    # requested: benchmark mode may pick different algorithms between runs.
    cudnn.enabled = True
    cudnn.benchmark = args.seed is None

    # Define data set and loader.
    train_dataset = TrainDataset(args)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=args.workers, pin_memory=True, drop_last=True)

    # Not used by the training loop itself; kept for the per-epoch evaluation
    # code that is currently commented out after this function.
    test_dataset = TestDataset(args)
    query_loader = torch.utils.data.DataLoader(test_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)

    labels = get_label_names(args.dataset)

    # Start training.
    logger.info('  Start training ...')

    for epoch in range(num_epochs):
        # Train.
        batch_time, data_time, losses, q_loss, align_loss, t_loss = train(train_loader, model, criterion, optimizer,
                                                                          scheduler, epoch, args)

        # Log
        logger.info('============== Epoch [{}] =============='.format(epoch))
        logger.info('Batch time: {:6.3f}'.format(batch_time))
        logger.info('Loading time: {:6.3f}'.format(data_time))
        logger.info('Total Loss  : {:.5f}'.format(losses))
        logger.info('Query Loss  : {:.5f}'.format(q_loss))
        logger.info('Align Loss  : {:.5f}'.format(align_loss))
        logger.info('Threshold Loss  : {:.5f}'.format(t_loss))

        # Checkpoint after every sub-epoch so an interrupted session keeps its progress.
        torch.save(model.state_dict(), args.save_model_path)

#         # Infer.
#         # Get support sample + mask for current class.

#        # Loop over classes.
#         class_dice = {}
#         class_iou = {}
#         for label_val, label_name in labels.items():
#     # Skip BG class.
#           if label_name is 'BG':
#             continue
#           logger.info('*------------------Class: {}--------------------*'.format(label_name))
#           logger.info('*--------------------------------------------------*')

#     # Get support sample + mask for current class.
#           support_sample = test_dataset.getSupport(label=label_val, all_slices=args.all_slices, N=args.n_shot)
#           test_dataset.label = label_val
#     # Infer.
#           with torch.no_grad():
#               scores = infer(model, query_loader, support_sample, args, logger, label_name)

#         # Log class-wise results
#           class_dice[label_name] = torch.tensor(scores.patient_dice).mean().item()
#           class_iou[label_name] = torch.tensor(scores.patient_iou).mean().item()

#           logger.info('Mean class IoU: {}'.format(class_iou[label_name]))
#           logger.info('Mean class Dice: {}'.format(class_dice[label_name]))
#           logger.info('*--------------------------------------------------*')

#     # Save trained model.
#         logger.info('  Saving model ...')
#         torch.save(model.state_dict(), args.save_model_path)

# def infer(model, query_loader, support_sample, args, logger, label_name):

#     # Test mode.
#     model.eval()

#     # Unpack support data.
#     support_image = [support_sample['image'][[i]].float().cuda() for i in range(support_sample['image'].shape[0])]  # n_shot x 3 x H x W
#     support_fg_mask = [support_sample['label'][[i]].float().cuda() for i in range(support_sample['image'].shape[0])]  # n_shot x H x W

#     # Loop through query volumes.
#     scores = Scores()
#     for i, sample in enumerate(query_loader):

#         # Unpack query data.
#         query_image = [sample['image'][i].float().cuda() for i in range(sample['image'].shape[0])]  # [C x 3 x H x W]
#         query_label = sample['label'].long()  # C x H x W
#         query_id = sample['id'][0].split('image_')[1][:-len('.nii.gz')]

#         # Compute output.
#         if args.EP1 is True:
#             # Match support slice and query sub-chunck.
#             query_pred = torch.zeros(query_label.shape[-3:])
#             C_q = sample['image'].shape[1]
#             idx_ = np.linspace(0, C_q, args.n_shot+1).astype('int')
#             for sub_chunck in range(args.n_shot):
#                 support_image_s = [support_image[sub_chunck]]  # 1 x 3 x H x W
#                 support_fg_mask_s = [support_fg_mask[sub_chunck]]  # 1 x H x W
#                 query_image_s = query_image[0][idx_[sub_chunck]:idx_[sub_chunck+1]]  # C' x 3 x H x W
#                 query_pred_s, _, _ = model([support_image_s], [support_fg_mask_s], [query_image_s], train=False)  # C x 2 x H x W
#                 query_pred_s = query_pred_s.argmax(dim=1).cpu()  # C x H x W
#                 query_pred[idx_[sub_chunck]:idx_[sub_chunck+1]] = query_pred_s

#         else:  # EP 2
#             query_pred, _, _ = model([support_image], [support_fg_mask], query_image, train=False)  # C x 2 x H x W
#             query_pred = query_pred.argmax(dim=1).cpu()  # C x H x W

#         # Record scores.
#         scores.record(query_pred, query_label)

#         # Log.
#         logger.info('    Tested query volume: ' + sample['id'][0][len(args.data_root):]
#                     + '. Dice score:  ' + str(scores.patient_dice[-1].item()))

#         # Save predictions.
#         file_name = 'image_' + query_id + '_' + label_name + '.pt'
#         torch.save(query_pred, os.path.join(args.save, file_name))

#     return scores



def train(train_loader, model, criterion, optimizer, scheduler, epoch, args):
    """Train for one pass over ``train_loader``; return averaged timings and losses.

    Each episode is moved to the GPU, fed through the model in training mode and
    optimised on the sum of the query, align and threshold losses. The optimizer
    and the learning-rate scheduler are stepped once per episode.

    Args:
        train_loader: iterable yielding dicts with the keys 'support_images',
            'support_fg_labels', 'query_images' and 'query_labels'.
        model: called as ``model(support_images, support_fg_mask, query_images,
            train=True, t_loss_scaler=...)``; returns the query prediction, the
            align loss and the threshold loss.
        criterion: loss applied to the log of the clamped query prediction.
        optimizer: optimizer stepped after each backward pass.
        scheduler: learning-rate scheduler stepped after each optimizer step.
        epoch: accepted for interface compatibility; not used in the body.
        args: namespace providing ``t_loss_scaler``.

    Returns:
        Tuple of averages: (batch time, data time, total loss, query loss,
        align loss, threshold loss).
    """
    time_per_batch = AverageMeter('Time', ':6.3f')
    time_per_load = AverageMeter('Data', ':6.3f')
    total_meter = AverageMeter('Loss', ':.4f')
    query_meter = AverageMeter('Query loss', ':.4f')
    align_meter = AverageMeter('Align loss', ':.4f')
    thresh_meter = AverageMeter('Threshold loss', ':.4f')

    # Clamp bounds that keep the log() of the prediction finite.
    lower = torch.finfo(torch.float32).eps
    upper = 1 - torch.finfo(torch.float32).eps

    # Train mode.
    model.train()

    started = time.time()
    for batch in tqdm.tqdm(train_loader):

        # Extract episode data: every shot gets a leading dimension and goes to the GPU.
        support_images = []
        for way in batch['support_images']:
            support_images.append([shot[None].float().cuda() for shot in way])
        support_fg_mask = []
        for way in batch['support_fg_labels']:
            support_fg_mask.append([shot[None].float().cuda() for shot in way])

        query_images = [image.float().cuda() for image in batch['query_images']]
        query_labels = torch.cat([label.long().cuda() for label in batch['query_labels']], dim=0)

        # Log loading time.
        time_per_load.update(time.time() - started)

        # Compute outputs and losses.
        query_pred, align_loss, thresh_loss = model(support_images, support_fg_mask, query_images,
                                                    train=True, t_loss_scaler=args.t_loss_scaler)

        clamped = torch.clamp(query_pred, lower, upper)
        query_loss = criterion(torch.log(clamped), query_labels[None])
        loss = query_loss + align_loss + thresh_loss

        # Drop stale gradients, then compute gradient and do SGD step.
        for parameter in model.parameters():
            parameter.grad = None

        loss.backward()
        optimizer.step()
        scheduler.step()

        # Log loss.
        weight = query_pred.size(0)
        total_meter.update(loss.item(), weight)
        query_meter.update(query_loss.item(), weight)
        align_meter.update(align_loss.item(), weight)
        thresh_meter.update(thresh_loss.item(), weight)

        # Log elapsed time.
        time_per_batch.update(time.time() - started)
        started = time.time()

    return (time_per_batch.avg, time_per_load.avg, total_meter.avg,
            query_meter.avg, align_meter.avg, thresh_meter.avg)


# Start training when this cell (or the equivalent script) is executed directly.
if __name__ == '__main__':
    main()


INFO:root:Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, workers=4, steps=15000, n_shot=1, n_query=1, n_way=1, batch_size=1, max_iterations=50, lr=0.001, lr_gamma=0.95, momentum=0.9, weight_decay=0.0005, seed=None, bg_wt=0.1, t_loss_scaler=1.0, min_size=200, all_slices=True, EP1=True)
Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, workers=4, steps=15000, n_shot=1, n_query=1, n_way=1, batch_size=1, max_iterations=50, lr=0.001, lr_gamma=0.95, momentum=0.9, weight_decay=0.0005, seed=None, bg_wt=0.1, t_loss_scaler=1.0, min_size=200, all_slices=True, EP1=True)


Loading pre-trained weights!


  pretrained_dict = torch.load('/content/resnext-101-kinetics.pth', map_location='cpu')
INFO:root:  Start training ...
  Start training ...
100%|██████████| 50/50 [03:54<00:00,  4.69s/it]
INFO:root:Batch time:  4.685
Batch time:  4.685
INFO:root:Loading time:  2.217
Loading time:  2.217
INFO:root:Total Loss  : 0.91030
Total Loss  : 0.91030
INFO:root:Query Loss  : 0.74235
Query Loss  : 0.74235
INFO:root:Align Loss  : 0.67075
Align Loss  : 0.67075
INFO:root:Threshold Loss  : -0.50280
Threshold Loss  : -0.50280
100%|██████████| 50/50 [05:14<00:00,  6.28s/it]
INFO:root:Batch time:  6.279
Batch time:  6.279
INFO:root:Loading time:  4.252
Loading time:  4.252
INFO:root:Total Loss  : 0.46326
Total Loss  : 0.46326
INFO:root:Query Loss  : 0.62491
Query Loss  : 0.62491
INFO:root:Align Loss  : 0.34546
Align Loss  : 0.34546
INFO:root:Threshold Loss  : -0.50712
Threshold Loss  : -0.50712
100%|██████████| 50/50 [03:37<00:00,  4.35s/it]
INFO:root:Batch time:  4.348
Batch time:  4.348
INFO:root:Loadin

In [32]:
#!/usr/bin/env python

import argparse
import numpy as np

import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
from torch.utils.data import DataLoader

# from models.fewshot_anom_3D import FewShotSeg
# from dataloading.datasets_3D import TestDataset
from dataloading.dataset_specifics import *
from utils import *


# def parse_arguments():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--data_root', type=str, required=True)
#     parser.add_argument('--save_root', type=str, required=True)
#     parser.add_argument('--pretrained_root', type=str, required=True)
#     parser.add_argument('--fold', type=int, required=True)
#     parser.add_argument('--dataset', type=str, required=True)
#     parser.add_argument('--n_shot', default=1, type=int)
#     parser.add_argument('--all_slices', default=True, type=bool)
#     parser.add_argument('--EP1', default=False, type=bool)
#     parser.add_argument('--seed', default=None, type=int)
#     parser.add_argument('--workers', default=0, type=int)

#     return parser.parse_args()


def main():
    """Evaluate the trained model on the test fold, one foreground class at a time.

    Loads the checkpoint written by training (``<save_root>/model.pth``), builds
    the test loader, and for every non-background label fetches a support sample,
    runs ``infer`` under ``torch.no_grad()`` and logs the mean per-patient IoU and
    Dice taken from the returned scores object. A summary over all classes is
    logged at the end.
    """
    import os  # Local import so the function does not depend on import order within the cell.

    args, unknown = parse_arguments()
    # Deterministic setting for reproducibility.
    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True

    # Setup the path to save.
    args.save = args.save_root

    logger = set_logger('/content/ADNet-VIN/infer_out/', 'infer.log')
    logger.info(args)

    # Init model and load state_dict. The checkpoint location is derived from
    # save_root (where training writes 'model.pth') instead of a fixed path, so
    # a non-default --save_root is honoured.
    model = FewShotSeg(args)
    model = nn.DataParallel(model.cuda())
    checkpoint_path = os.path.join(args.save_root, 'model.pth')
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    # Data loader.
    test_dataset = TestDataset(args)
    query_loader = DataLoader(test_dataset,
                              batch_size=1,
                              shuffle=False,
                              num_workers=args.workers,
                              pin_memory=True,
                              drop_last=True)

    # Inference.
    logger.info('  Start inference ... Note: EP1 is ' + str(args.EP1))
    logger.info('  Support: ' + str(test_dataset.support_dir[len(args.data_root):]))
    logger.info('  Query: ' +
                str([elem[len(args.data_root):] for elem in test_dataset.image_dirs]))

    # Get unique labels (classes).
    labels = get_label_names(args.dataset)

    # The prediction is compared per label value by the metric helpers; two
    # values (0 = background, 1 = foreground) matches the two-class loss weights
    # used for training. NOTE(review): confirm the model emits exactly two channels.
    num_classes = 2

    # Loop over classes.
    class_dice = {}
    class_iou = {}
    for label_val, label_name in labels.items():

        # Skip BG class. Compare by value: `is` tests object identity and only
        # works for string literals by accident of interning.
        if label_name == 'BG':
            continue

        logger.info('  *------------------Class: {}--------------------*'.format(label_name))
        logger.info('  *--------------------------------------------------*')

        # Get support sample + mask for current class.
        support_sample = test_dataset.getSupport(label=label_val, all_slices=args.all_slices, N=args.n_shot)
        test_dataset.label = label_val

        # Infer. infer() takes the number of classes as its last argument.
        with torch.no_grad():
            scores = infer(model, query_loader, support_sample, args, logger, label_name, num_classes)

        # Log class-wise results
        class_dice[label_name] = torch.tensor(scores.patient_dice).mean().item()
        class_iou[label_name] = torch.tensor(scores.patient_iou).mean().item()

        logger.info('Mean class IoU: {}'.format(class_iou[label_name]))
        logger.info('Mean class Dice: {}'.format(class_dice[label_name]))
        logger.info('  *--------------------------------------------------*')

    # Log final results.
    logger.info('  *-----------------Final results--------------------*')
    logger.info('  *--------------------------------------------------*')
    logger.info('  Mean IoU: {}'.format(class_iou))
    logger.info('  Mean Dice: {}'.format(class_dice))
    logger.info('  *--------------------------------------------------*')


import torch
import os
import numpy as np
from sklearn.metrics import jaccard_score

def dice_score(pred, target, num_classes, empty_score=1.0):
    """Compute the Dice coefficient separately for each class value.

    For class ``i`` the score is ``2 * |pred == i AND target == i|`` divided by
    ``|pred == i| + |target == i|``.

    Args:
        pred: integer tensor of predicted class values.
        target: integer tensor of reference class values (broadcastable with
            ``pred``, on the same device).
        num_classes: classes ``0 .. num_classes - 1`` are scored.
        empty_score: value reported for a class that occurs in neither ``pred``
            nor ``target``. Without this guard the division is 0/0 and yields
            NaN, which then contaminates any mean taken over the scores. The
            default of 1.0 treats "both agree the class is absent" as a perfect
            match; pass ``float('nan')`` to get the unguarded result.

    Returns:
        List of ``num_classes`` Python floats.
    """
    dice = []
    for class_id in range(num_classes):
        pred_i = (pred == class_id).float()
        target_i = (target == class_id).float()
        intersection = (pred_i * target_i).sum()
        denominator = pred_i.sum() + target_i.sum()
        if denominator.item() == 0:
            # Class absent from both inputs: avoid the 0/0 division.
            dice.append(float(empty_score))
        else:
            dice.append((2. * intersection / denominator).item())
    return dice

def iou_score(pred, target, num_classes, empty_score=1.0):
    """Compute the intersection-over-union separately for each class value.

    For class ``i`` the score is ``|pred == i AND target == i|`` divided by
    ``|pred == i| + |target == i| - |pred == i AND target == i|``.

    Args:
        pred: integer tensor of predicted class values.
        target: integer tensor of reference class values (broadcastable with
            ``pred``, on the same device).
        num_classes: classes ``0 .. num_classes - 1`` are scored.
        empty_score: value reported for a class that occurs in neither ``pred``
            nor ``target``. Without this guard the division is 0/0 and yields
            NaN, which then contaminates any mean taken over the scores. The
            default of 1.0 treats "both agree the class is absent" as a perfect
            match; pass ``float('nan')`` to get the unguarded result.

    Returns:
        List of ``num_classes`` Python floats.
    """
    iou = []
    for class_id in range(num_classes):
        pred_i = (pred == class_id).float()
        target_i = (target == class_id).float()
        intersection = (pred_i * target_i).sum()
        union = pred_i.sum() + target_i.sum() - intersection
        if union.item() == 0:
            # Class absent from both inputs: avoid the 0/0 division.
            iou.append(float(empty_score))
        else:
            iou.append((intersection / union).item())
    return iou

def infer(model, query_loader, support_sample, args, logger, label_name, num_classes=2):
    """Segment every query volume with the given support sample and score the result.

    Args:
        model: called as ``model(support_images, support_masks, query_images,
            train=False)``; the first element of its return value is the
            per-class prediction, reduced here with ``argmax(dim=1)``.
        query_loader: iterable yielding dicts with the keys 'image', 'label'
            and 'id'.
        support_sample: dict with the keys 'image' and 'label'; indexed along
            dimension 0, once per support shot.
        args: namespace providing ``EP1``, ``n_shot``, ``data_root`` and ``save``.
        logger: logger used for the per-volume report.
        label_name: class name, used in the saved prediction's file name.
        num_classes: number of label values scored by ``dice_score`` and
            ``iou_score``. Defaults to 2 so that callers passing only the first
            six arguments keep working.

    Returns:
        The ``Scores`` object that received one ``record`` call per query volume.
    """

    # Test mode.
    model.eval()

    # Unpack support data.
    n_support = support_sample['image'].shape[0]
    # NOTE(review): images are indexed with [i] (drops the leading dimension)
    # while masks use [[i]] (keeps it) - confirm this asymmetry is intended.
    support_image = [support_sample['image'][i].float().cuda() for i in range(n_support)]
    support_fg_mask = [support_sample['label'][[i]].float().cuda() for i in range(n_support)]

    # Loop through query volumes.
    scores = Scores()
    for sample in query_loader:

        # Unpack query data.
        query_image = [sample['image'][j].float().cuda() for j in range(sample['image'].shape[0])]  # [C x 3 x H x W]
        # Keep the label on the CPU: the prediction is moved to the CPU below,
        # and the metric helpers multiply the two tensors element-wise, which
        # fails if they live on different devices.
        query_label = sample['label'].long()

        query_id = sample['id'][0].split('image_')[1][:-len('.nii.gz')]

        # Compute output.
        if args.EP1:
            # Match support slice and query sub-chunk: the query slices are split
            # into n_shot contiguous ranges, each segmented with its own support shot.
            query_pred = torch.zeros(query_label.shape[-3:], dtype=torch.long)
            C_q = sample['image'].shape[1]
            idx_ = np.linspace(0, C_q, args.n_shot+1).astype('int')
            for sub_chunck in range(args.n_shot):
                lo, hi = idx_[sub_chunck], idx_[sub_chunck+1]
                support_image_s = [support_image[sub_chunck]]
                support_fg_mask_s = [support_fg_mask[sub_chunck]]
                query_image_s = query_image[0][lo:hi]
                query_pred_s, _, _ = model([support_image_s], [support_fg_mask_s], [query_image_s], train=False)
                query_pred[lo:hi] = query_pred_s.argmax(dim=1).cpu()

        else:  # EP 2
            query_pred_, _, _ = model([support_image], [support_fg_mask], query_image, train=False)
            query_pred = query_pred_.argmax(dim=1).cpu()

        # Calculate metrics.
        dice = dice_score(query_pred, query_label, num_classes)
        iou = iou_score(query_pred, query_label, num_classes)

        # Record scores.
        # NOTE(review): confirm that Scores.record accepts the extra dice/iou arguments.
        scores.record(query_pred, query_label, dice, iou)

        # Log.
        logger.info('    Tested query volume: ' + sample['id'][0][len(args.data_root):]
                    + '. Dice score:  ' + str(dice) + '. IoU score: ' + str(iou))

        # Save predictions.
        file_name = 'image_' + query_id + '_' + label_name + '.pt'
        torch.save(query_pred, os.path.join(args.save, file_name))

    return scores

  if label_name is 'BG':
INFO:root:Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, workers=4, steps=15000, n_shot=1, n_query=1, n_way=1, batch_size=1, max_iterations=50, lr=0.001, lr_gamma=0.95, momentum=0.9, weight_decay=0.0005, seed=None, bg_wt=0.1, t_loss_scaler=1.0, min_size=200, all_slices=True, EP1=True, save='/content/ADNet-VIN/log_out')
Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, workers=4, steps=15000, n_shot=1, n_query=1, n_way=1, batch_size=1, max_iterations=50, lr=0.001, lr_gamma=0.95, momentum=0.9, weight_decay=0.0005, seed=None, bg_wt=0.1, t_loss_scaler=1.0, min_size=200, all_slices=True, EP1=True, save='/content/ADNet-VIN/log_out')
Namespace(data_root='/content/ADNet-VIN/dataloading/', save_root='/content/ADNet-VIN/log_out', dataset='CHAOST2', n_sv=1, fold=1, max_slices=10, wo

Loading pre-trained weights!


  pretrained_dict = torch.load('/content/resnext-101-kinetics.pth', map_location='cpu')
  model.load_state_dict(torch.load('/content/ADNet-VIN/log_out/model.pth', map_location="cpu"))
INFO:root:  Start inference ... Note: EP1 is True
  Start inference ... Note: EP1 is True
  Start inference ... Note: EP1 is True
INFO:root:  Support: chaos_MR_T2_normalized/image_19.nii.gz
  Support: chaos_MR_T2_normalized/image_19.nii.gz
  Support: chaos_MR_T2_normalized/image_19.nii.gz
INFO:root:  Query: ['chaos_MR_T2_normalized/image_8.nii.gz', 'chaos_MR_T2_normalized/image_10.nii.gz', 'chaos_MR_T2_normalized/image_13.nii.gz', 'chaos_MR_T2_normalized/image_15.nii.gz']
  Query: ['chaos_MR_T2_normalized/image_8.nii.gz', 'chaos_MR_T2_normalized/image_10.nii.gz', 'chaos_MR_T2_normalized/image_13.nii.gz', 'chaos_MR_T2_normalized/image_15.nii.gz']
  Query: ['chaos_MR_T2_normalized/image_8.nii.gz', 'chaos_MR_T2_normalized/image_10.nii.gz', 'chaos_MR_T2_normalized/image_13.nii.gz', 'chaos_MR_T2_normalized/ima

IndexError: Caught IndexError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 52, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/content/ADNet-VIN/dataloading/datasets_3D.py", line 58, in __getitem__
    sample['image'] = torch.from_numpy(img[idx])
IndexError: boolean index did not match indexed array along dimension 0; dimension is 1 but corresponding boolean dimension is 34
