In [1]:
import torchvision.models as models
import dsmil as mil

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable

from torch.utils.tensorboard import SummaryWriter

import sys
import argparse
import os
import glob
import pandas as pd
import csv
import numpy as np
import random
import math
from skimage import io, img_as_float
from sklearn.utils import shuffle
from sklearn.metrics import roc_curve, roc_auc_score, accuracy_score
from PIL import Image
from collections import OrderedDict
import matplotlib.pyplot as plt
from collections import OrderedDict

import torchvision.transforms.functional as VF
from torchvision import transforms, utils

In [2]:
args = argparse.Namespace()
args.num_classes = 2
args.num_feats = 512
args.num_epochs = 90
args.batch_size = 128
args.num_workers = 0
args.top_k = 8
args.lr = 0.0001
args.patch_size = 224
args.img_channel = 3
args.dataset = 'wsi-tcga-lung'
args.backbone = 'resnet18'
args.magnification = '20x'

In [3]:
if args.backbone == 'resnet18':
    resnet = models.resnet18(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = 512
if args.backbone == 'resnet34':
    resnet = models.resnet34(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = 512
if args.backbone == 'resnet50':
    resnet = models.resnet50(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = 2048
if args.backbone == 'resnet101':
    resnet = models.resnet101(pretrained=False, norm_layer=nn.InstanceNorm2d)
    num_feats = 2048
for param in resnet.parameters():
    param.requires_grad = False
resnet.fc = nn.Identity()
i_classifier = mil.IClassifier(resnet, num_feats, output_class=args.num_classes).cuda()

In [4]:
weight_path = glob.glob('simclr/runs/*/checkpoints/*.pth')[-1]
state_dict_weights = torch.load(weight_path)
try:
    state_dict_weights.pop('module.l1.weight')
    state_dict_weights.pop('module.l1.bias')
    state_dict_weights.pop('module.l2.weight')
    state_dict_weights.pop('module.l2.bias')
except:
    state_dict_weights.pop('l1.weight')
    state_dict_weights.pop('l1.bias')
    state_dict_weights.pop('l2.weight')
    state_dict_weights.pop('l2.bias')
state_dict_init = i_classifier.state_dict()
new_state_dict = OrderedDict()
for (k, v), (k_0, v_0) in zip(state_dict_weights.items(), state_dict_init.items()):
    name = k_0
    new_state_dict[name] = v
i_classifier.load_state_dict(new_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['fc.weight', 'fc.bias'], unexpected_keys=[])

In [15]:
class BagDataset():
    def __init__(self, csv_file, transform=None):
        self.files_list = csv_file
        self.transform = transform
    def __len__(self):
        return len(self.files_list)
    def __getitem__(self, idx):
        temp_path = self.files_list[idx]
        img = os.path.join(temp_path)
        img = Image.open(img)
        sample = {'input': img}
        
        if self.transform:
            sample = self.transform(sample)
        return sample 

class ToTensor(object):
    def __call__(self, sample):
        img = sample['input']
        img = VF.to_tensor(img)
        return {'input': img} 
    
class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img

def bag_dataset(csv_file_path):
    transformed_dataset = BagDataset(csv_file=csv_file_path,
                                    transform=Compose([
                                        ToTensor()
                                    ]))
    dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, drop_last=False)
    return dataloader, len(transformed_dataset)

def compute_feats(bags_list, save_path=None):
    num_bags = len(bags_list)
    Tensor = torch.FloatTensor
    for i in range(0, num_bags):
        feats_list = []
        if args.magnification == '20x':
            csv_file_path = glob.glob(os.path.join(bags_list[i], '*/*.jpg'))
        else:
            csv_file_path = glob.glob(os.path.join(bags_list[i], '*.jpg'))
        print(len(csv_file_path))
        dataloader, bag_size = bag_dataset(csv_file_path)
        with torch.no_grad():
            for iteration, batch in enumerate(dataloader):
                patches = batch['input'].float().cuda() 
                feats, classes = i_classifier(patches)
                feats = feats.cpu().numpy()
                feats_list.extend(feats)
        df = pd.DataFrame(feats_list)
        os.makedirs(os.path.join(save_path, bags_list[i].split(os.path.sep)[-3]), exist_ok=True)
        df.to_csv(os.path.join(save_path, bags_list[i].split(os.path.sep)[-3], bags_list[i].split(os.path.sep)[-2]+'.csv'), index=False, float_format='%.4f')
        print(i)

In [16]:
if args.dataset == 'wsi-tcga-lung':
    bags_path = os.path.join('WSI', 'TCGA-lung', 'pyramid', '*', '*')
feats_path = os.path.join('datasets', args.dataset)
os.makedirs(feats_path, exist_ok=True)
bags_list = glob.glob(bags_path+os.path.sep)
compute_feats(bags_list, feats_path)

5667
0
6181
1
5991
2
25254
3
9857
4
