In [1]:
import numpy as np
import tifffile
import os
import pandas as pd
import glob
import random
import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim

In [2]:
def same_seeds(seed):
    # Python built-in random module
    random.seed(seed)
    # Numpy
    np.random.seed(seed)
    # Torch
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

same_seeds(5)

In [3]:
def set_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:1')
    else:
        device = torch.device('cpu')
        
    print(device)

    return device


device = set_device()

cuda:1


In [50]:
def test_transform(img):
    x = np.zeros(img.shape,dtype=np.float32)
    for i in range(img.shape[0]):
        if img[i].min() >= 0:
            img[i]/=255
        else:
            img[i] = (img[i] - img[i].min())/(img[i].max()-img[i].min())
        
        m = img[i].mean()
        s = img[i].std()
        x[i] = (img[i]-m)/s
    
    x = torch.tensor(x,dtype=torch.float32)

    return x

class test_cycledataset():
    def __init__(self,filelist,transform):
        self.filelist = filelist
        self.transform = transform

        
    def __getitem__(self,index):
        if isinstance(self.filelist,list):
            # order of channel: (nucleus, mitochondria, bright-field)
            #                   (nucleus, bright-field, predicted mitochondria)
            # It is necessary to follow the above order.
                         
            img = tifffile.imread(self.filelist[index])
            # img channel: (nucleus,mitochondria,bright-field,predicted mito)
            
#             img = img[0:3]  #(nucleus, mitochondria, bright-field)
            img = np.concatenate((img[0:2],np.expand_dims(img[3],axis=0)),axis=0)
            # (nucleus, bright-field, predicted mitochondria)
            
            img = self.transform(img)
            return img.to(device)

        else:
            img = tifffile.imread(self.filelist.iloc[index,0])
            # img channel: (nucleus,mitochondria,bright-field,predicted mito)

#             img = img[0:3]  #(nucleus, mitochondria, bright-field)
            img = np.concatenate((img[0:2],np.expand_dims(img[3],axis=0)),axis=0)
            # (nucleus, bright-field, predicted mitochondria)
            
            img = self.transform(img)
            return img.to(device)
    
    def getbatch(self,indices):
        filepath = []
        for index in indices:
            filepath.append(filelist[index])
        return filepath
                        
    def __len__(self):
        if isinstance(self.filelist,list): 
            return len(self.filelist)
        else: 
            return self.filelist.shape[0]

In [51]:


#test = glob.glob('path to tiff files')
test = pd.read_csv(r"path to csv file of test data set")
print(f'number of test set:{len(test)}')
testdataset = test_cycledataset(test, test_transform)

from torch.utils.data import DataLoader

bs = 64

testdataloader = DataLoader(testdataset,batch_size=bs)

number of test set:5


In [52]:
classifier = models.resnet34()

In [53]:
classifier.fc = torch.nn.Sequential(
    torch.nn.Linear(512,2,bias=True)
)

In [54]:
# remember to change the folder name

save_dir = r"path to the folder"
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

In [55]:
classifier.load_state_dict(torch.load('path to the trained model',map_location=device))
classifier.to(device)
classifier.eval()

y_pred = torch.tensor([])

for i, x in enumerate(testdataloader):
    with torch.no_grad():
        yhat = classifier(x)
        p = torch.max(yhat,1)[1]  # return the index of maximum (max prob after sofmax)
        y_pred = torch.cat([y_pred,p.cpu()])

In [56]:
y_pred

tensor([0., 0., 0., 0., 0.])

In [57]:
if isinstance(test,list):
    df = pd.DataFrame({'path':test,'prediction':y_pred})
else:
    df = pd.DataFrame({'path':test['path'],'prediction':y_pred})

In [58]:
df.to_csv(save_dir+'pdmito_label_prediction.csv',index=False)