In [1]:
import cv2,os
from skimage import io
from PIL import Image
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm_notebook as tqdm

from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score,precision_score,accuracy_score,roc_curve

import torch
from torch.utils.data import Dataset,TensorDataset,random_split,DataLoader,SubsetRandomSampler
from torch.utils.data.dataset import Subset
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import torchvision.models as models
import sys

if not sys.warnoptions:
    import warnings
    warnings.simplefilter("ignore")

  warn(


In [2]:
# Select the compute device: prefer CUDA when available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print("torch.device(cuda)")
    print("torch.cuda.device_count(): ", torch.cuda.device_count())
    for i in range(torch.cuda.device_count()):
        # Pass the index explicitly: without it, every iteration reports the
        # *current* device rather than device i.
        print(torch.cuda.get_device_name(i))
    print("torch.cuda.current_device()", torch.cuda.current_device())
else:
    device = torch.device("cpu")
    print("torch.device(cpu)")

torch.device(cuda)
torch.cuda.device_count():  4
Tesla V100-SXM2-16GB
Tesla V100-SXM2-16GB
Tesla V100-SXM2-16GB
Tesla V100-SXM2-16GB
torch.cuda.current_device() 0


# 1. Read data

In [3]:
ls Datasets/CTRL*

Datasets/CTRL_All.npy   Datasets/CTRL_Dapi.npy
Datasets/CTRL_CTCF.npy  Datasets/CTRL_H3K27ac.npy


In [4]:
ls Datasets/RETT*

Datasets/RETT_HPS3042_All.npy      Datasets/RETT_HPS3084_All.npy
Datasets/RETT_HPS3042_CTCF.npy     Datasets/RETT_HPS3084_CTCF.npy
Datasets/RETT_HPS3042_Dapi.npy     Datasets/RETT_HPS3084_Dapi.npy
Datasets/RETT_HPS3042_H3K27ac.npy  Datasets/RETT_HPS3084_H3K27ac.npy
Datasets/RETT_HPS3049_All.npy      Datasets/RETT_HPS9999_All.npy
Datasets/RETT_HPS3049_CTCF.npy     Datasets/RETT_HPS9999_CTCF.npy
Datasets/RETT_HPS3049_Dapi.npy     Datasets/RETT_HPS9999_Dapi.npy
Datasets/RETT_HPS3049_H3K27ac.npy  Datasets/RETT_HPS9999_H3K27ac.npy


In [5]:
# Which staining channel and which Rett cell line to load.
stain_type = "H3K27ac"
rett_type  = "HPS3042"

# NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
X_Ctrl = np.load(f"./Datasets/CTRL_{stain_type}.npy", allow_pickle=True)
X_Rett = np.load(f"./Datasets/RETT_{rett_type}_{stain_type}.npy", allow_pickle=True)

# Integer class labels: control samples are 0, Rett samples are 1.
y_Ctrl = torch.zeros(len(X_Ctrl), dtype=torch.int64)
y_Rett = torch.ones(len(X_Rett), dtype=torch.int64)

# Controls first, then Rett samples, so images and labels stay aligned.
X = np.concatenate([X_Ctrl, X_Rett], axis=0)
y = torch.cat([y_Ctrl, y_Rett], dim=0)

# 2. Data processing

In [6]:
class cell_dataset(Dataset):
    """Dataset pairing each image with a one-hot encoded two-class label.

    Args:
        x: indexable collection of images accepted by ``transforms.ToTensor``.
        y: tensor of integer class indices (0 or 1), one entry per image.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y
        self.transform = transforms.ToTensor()

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        # Both the image and the one-hot label are returned as float tensors.
        image = self.transform(self.x[idx]).float()
        label = F.one_hot(self.y[idx], num_classes=2).float()
        return image, label


dataset = cell_dataset(X, y)

In [7]:
# Mini-batch size used by the DataLoaders built in this notebook.
batch_size = 64
# 80/20 split sizes; only the commented-out random_split below consumes them.
train_size = int(len(X)*0.8)
valid_size = len(X) - train_size

# Disabled train/validation split, kept for reference:
# train_data, valid_data = random_split(dataset=dataset, lengths=[train_size, valid_size], 
#                                       generator=torch.Generator().manual_seed(42))
# dataloader_train = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# dataloader_valid = DataLoader(valid_data, batch_size=batch_size, shuffle=True)
# With the split disabled, the "validation" loader iterates over the whole dataset.
dataloader_valid = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 3. ResNet model

In [8]:
# Locate the Fold-0 checkpoint for the chosen cell line / stain / architecture.
model_type="Resnet10_noavg"
rett_type_test = "HPS3042"
homepath="/groups/4/gaa50089/acd13264yb/Rettsyndrome/Classification"
modelpath=f"{homepath}/results/{rett_type_test}_{stain_type}_{model_type}/{rett_type_test}_{stain_type}_{model_type}_Fold0.pkl"
# Print before loading so the path is visible even when the load fails.
print(modelpath)
# map_location lets a checkpoint saved on a GPU load on whichever device was
# selected above (including the CPU fallback).
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
weight = torch.load(modelpath, map_location=device)

/groups/4/gaa50089/acd13264yb/Rettsyndrome/Classification/results/HPS3042_H3K27ac_Resnet10_noavg/HPS3042_H3K27ac_Resnet10_noavg_Fold0.pkl


In [27]:
# Import the architecture whose module name matches model_type; each module
# exposes a class called MyModel.
if model_type=="Resnet10_noavg":
    from models.Resnet10_noavg import MyModel
elif model_type=="Resnet10":
    from models.Resnet10 import MyModel
elif model_type=="Resnet18":
    from models.Resnet18 import MyModel
else:
    # Fail loudly instead of silently reusing a MyModel left over in the kernel.
    raise ValueError(f"Unknown model_type: {model_type!r}")

model = MyModel().to(device)

In [28]:
# Number of GPUs to use; the model is only wrapped when this is greater than 1.
ngpu = 1
if (device.type == 'cuda') and (ngpu > 1):
    # Wrap the model in DataParallel over GPU ids 0..ngpu-1.
    model = nn.DataParallel(model, list(range(ngpu)))

In [29]:
# Smoke test: push a single all-ones 3-channel image through the network and
# show the raw output, its softmax, and the predicted class index.
image_size = 500
test_input = torch.ones(1, 3, image_size, image_size, device=device)
output = model(test_input)
print(output.shape)
print(output)
print(F.softmax(output, dim=1))
print(torch.argmax(output, dim=1))

torch.Size([1, 2])
tensor([[-0.0843,  0.0411]], device='cuda:0', grad_fn=<AddmmBackward0>)
tensor([[0.4687, 0.5313]], device='cuda:0', grad_fn=<SoftmaxBackward0>)
tensor([1], device='cuda:0')


# 4. Train

In [12]:
# NOTE(review): this sets an `avgpool` attribute on the model; whether the
# model's forward pass uses it is not visible in this notebook — confirm.
model.avgpool = nn.AdaptiveAvgPool2d(1)

# loss_function = nn.BCELoss()
# Per-class weights: total sample count divided by that class's sample count,
# so the rarer class gets the larger weight. Created on `device` (not a
# hard-coded .cuda()) so the cell also runs on the CPU fallback.
weights = torch.tensor([(len(X_Ctrl)+len(X_Rett))/len(X_Ctrl),
                        (len(X_Ctrl)+len(X_Rett))/len(X_Rett)], device=device)
# loss_function = nn.CrossEntropyLoss(weight=weights)
loss_function = nn.BCEWithLogitsLoss(pos_weight=weights)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Multiplies the learning rate by 0.99 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

def train(model,device,dataloader_train,loss_function,optimizer):
    """Run one training epoch.

    Args:
        model: network to optimise; put into training mode here.
        device: device each batch is moved to before the forward pass.
        dataloader_train: iterable of (x, y) batches, y being one-hot labels
            whose column 1 is the positive-class indicator.
        loss_function: callable ``loss_function(output, y)`` returning a scalar.
        optimizer: optimizer stepping the model's parameters.

    Returns:
        (mean batch loss, accuracy over all samples seen).
    """
    losses_train = []
    n_train = 0
    acc_train = 0
    # No optimizer.step() here: stepping before any backward pass would
    # re-apply whatever stale gradients are left from the previous epoch.
    model.train()
    for x, y in dataloader_train:
        n_train += y.size()[0]
        model.zero_grad()  # reset gradients
        x = x.to(device)  # move tensors to the target device
        y = y.to(device)
        output = model(x)  # forward pass
        loss = loss_function(output, y)  # compute the loss
        loss.backward()  # back-propagate
        optimizer.step()  # update parameters
        # Predicted class (argmax) vs. the positive-class column of the one-hot label.
        acc_train += (output.argmax(1) == y[:,1]).float().sum().item()
        losses_train.append(loss.item())
    return np.mean(losses_train), (acc_train/n_train)
        
def valid(model,device,dataloader_valid,loss_function):
    """Evaluate the model for one pass over ``dataloader_valid``.

    NOTE: a later cell in this notebook redefines ``valid`` with a different
    signature; this four-argument version is the one meant for the training loop.

    Args:
        model: network to evaluate; put into eval mode here.
        device: device each batch is moved to before the forward pass.
        dataloader_valid: iterable of (x, y) batches, y being one-hot labels
            whose column 1 is the positive-class indicator.
        loss_function: callable ``loss_function(output, y)`` returning a scalar.

    Returns:
        (mean batch loss, accuracy over all samples seen).
    """
    losses_valid = []
    n_val = 0
    acc_val = 0
    model.eval()
    # Evaluation needs no gradients; skipping autograd saves memory and time.
    with torch.no_grad():
        for x, y in dataloader_valid:
            n_val += y.size()[0]
            x = x.to(device)  # move tensors to the target device
            y = y.to(device)
            output = model(x)  # forward pass
            loss = loss_function(output, y)  # compute the loss
            acc_val += (output.argmax(1) == y[:,1]).float().sum().item()
            losses_valid.append(loss.item())
    return np.mean(losses_valid), (acc_val/n_val)

# Per-epoch metric log; the (currently commented-out) training loop appends to these lists.
history = {'loss_train': [], 'loss_valid': [],'acc_train':[],'acc_valid':[]}

In [14]:
# n_epochs = 10
# for epoch in range(n_epochs):
#     loss_train, acc_train = train(model,device,dataloader_train,loss_function,optimizer)
#     loss_valid, acc_valid = valid(model,device,dataloader_valid,loss_function)
#     scheduler.step()
    
#     history['loss_train'].append(loss_train)
#     history['loss_valid'].append(loss_valid)
#     history['acc_train'].append(acc_train)
#     history['acc_valid'].append(acc_valid)
#     print('EPOCH: {}, Train [Loss: {:.3f}, Accuracy: {:.3f}], Valid [Loss: {:.3f}, Accuracy: {:.3f}]'
#           .format(epoch, loss_train, acc_train, loss_valid, acc_valid))

In [15]:
# # train processing plot
# n_epochs = 50
# epochs=range(1,n_epochs+1)
# plt.ylim(0,1.0)
# plt.plot(epochs, history['acc_train'], 'b', label='Training accuracy')  
# plt.plot(epochs, history['acc_valid'], 'r', label='Validation accuracy')
# plt.title('Training and Validation accuracy')
# plt.legend()
# plt.figure()
# plt.show()

# 5. Validate data

In [36]:
def plot_and_save_roc_curve(y_true, y_scores, save_path=None):
    """Compute the ROC AUC and, optionally, plot and save the ROC curve.

    Args:
        y_true: binary ground-truth labels.
        y_scores: predicted scores for the positive class.
        save_path: if given, the ROC curve is drawn and written to this path.
            With the default ``None`` nothing is plotted, so the curve points
            are not computed at all.

    Returns:
        The ROC AUC score.
    """
    auc_score = roc_auc_score(y_true, y_scores)
    if save_path is not None:
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        fig, ax = plt.subplots()
        ax.plot(fpr, tpr, label=f"AUC = {auc_score:.3f}")
        ax.plot([0, 1], [0, 1], linestyle="--", color="gray")  # chance diagonal
        ax.set(xlabel="False positive rate", ylabel="True positive rate", title="ROC curve")
        ax.legend(loc="lower right")
        fig.savefig(save_path)
        # Close explicitly: this function is called once per fold in a loop.
        plt.close(fig)
    return auc_score

def valid(model, device, dataloader_valid):
    """Evaluate ``model`` on ``dataloader_valid`` and return (accuracy, AUC).

    Labels are read from column 1 of the one-hot target, and the positive-class
    score is the sigmoid of the model's second output. Note that this
    definition replaces the earlier four-argument ``valid`` in the notebook.
    """
    model.eval()
    labels, scores = [], []
    n_correct, n_seen = 0, 0
    for x, y in dataloader_valid:
        n_seen += y.size()[0]
        x = x.to(device)
        y = y.to(device)
        with torch.no_grad():
            output = model(x)
        target = y[:, 1]  # second one-hot column is the positive-class label
        labels.extend(target.tolist())
        # Second model output is used as the positive-class score.
        scores.extend(torch.sigmoid(output[:, 1]).tolist())
        n_correct += (output.argmax(1) == target).float().sum().item()
    # Delegate the AUC computation to the ROC helper.
    auc_score = plot_and_save_roc_curve(labels, scores)
    return n_correct / n_seen, auc_score

def loaddata(stain_type, rett_type):
    """Build a ``cell_dataset`` of control plus Rett images for one stain.

    Reads the arrays from ``{homepath}/Datasets_LR`` (``homepath`` is a
    notebook-level global). Control images are labelled 0 and Rett images 1,
    with controls placed first.
    """
    ctrl_images = np.load(f"{homepath}/Datasets_LR/CTRL_{stain_type}.npy", allow_pickle=True)
    rett_images = np.load(f"{homepath}/Datasets_LR/RETT_{rett_type}_{stain_type}.npy", allow_pickle=True)
    images = np.concatenate([ctrl_images, rett_images], axis=0)
    labels = torch.cat(
        [
            torch.zeros(len(ctrl_images), dtype=torch.int64),
            torch.ones(len(rett_images), dtype=torch.int64),
        ],
        dim=0,
    )
    return cell_dataset(images, labels)

In [37]:
# Cross-evaluation: for every stain, score each cell line's data with the
# per-fold checkpoints trained on every cell line.
stain_type = "H3K27ac"
rett_type  = "HPS3042"
model_type="Resnet10_noavg"
rett_type_model = "HPS3042"
homepath="/groups/4/gaa50089/acd13264yb/Rettsyndrome/Classification"

stain_list = ["All", "H3K27ac", "CTCF", "Dapi"]
rett_list = ["HPS3042", "HPS3049", "HPS3084"]
n_splits=5
for stain_type in stain_list:
    print(f"{stain_type}")

    for rett_type in rett_list:
        # The dataset and its folds depend only on (stain, data cell line), so
        # load and split once here instead of once per model cell line.
        dataset = loaddata(stain_type, rett_type)
        splits=KFold(n_splits,shuffle=True,random_state=42)
        fold_indices = list(splits.split(np.arange(len(dataset))))

        for rett_type_model in rett_list:
            history = {'acc_valid':[], 'auc_valid':[]}

            for fold, (train_idx, val_idx) in enumerate(fold_indices):
#                 if fold != 0: break
                print(f"Data {rett_type}, Model {rett_type_model}_Fold{fold}", end='  ')
                # Only the held-out indices of this fold are evaluated.
                valid_sampler = SubsetRandomSampler(val_idx)
                dataloader_valid = DataLoader(dataset, batch_size=batch_size, sampler=valid_sampler)

                modelpath=f"{homepath}/results_LR/{rett_type_model}_{stain_type}_{model_type}/{rett_type_model}_{stain_type}_{model_type}_Fold{fold}.pkl"
                # map_location lets GPU-saved checkpoints load on the selected device.
                weight = torch.load(modelpath, map_location=device)

                # MyModel comes from the architecture-selection cell and must
                # correspond to model_type.
                model = MyModel().to(device)
                model.resnet.load_state_dict(weight)
                model.avgpool = nn.AdaptiveAvgPool2d(1)

                acc_valid, auc_valid = valid(model, device, dataloader_valid)
                print(f'Accuracy: {acc_valid:.3f}, AUC: {auc_valid:.3f}')
                history['acc_valid'].append(acc_valid)
                history['auc_valid'].append(auc_valid)
            # Mean accuracy and mean AUC over the folds for this data/model pair.
            print(np.average(history['acc_valid']), np.average(history['auc_valid']))
            print("")

All
Data HPS3042, Model HPS3042_Fold0  Accuracy: 0.975, AUC: 0.996
Data HPS3042, Model HPS3042_Fold1  Accuracy: 0.990, AUC: 0.999
Data HPS3042, Model HPS3042_Fold2  Accuracy: 0.975, AUC: 0.997
Data HPS3042, Model HPS3042_Fold3  Accuracy: 0.979, AUC: 0.997
Data HPS3042, Model HPS3042_Fold4  Accuracy: 0.983, AUC: 0.998
0.9805938801376064 0.9975356603621094

Data HPS3042, Model HPS3049_Fold0  Accuracy: 0.842, AUC: 0.931
Data HPS3042, Model HPS3049_Fold1  Accuracy: 0.850, AUC: 0.947
Data HPS3042, Model HPS3049_Fold2  Accuracy: 0.816, AUC: 0.954
Data HPS3042, Model HPS3049_Fold3  Accuracy: 0.762, AUC: 0.868
Data HPS3042, Model HPS3049_Fold4  Accuracy: 0.804, AUC: 0.890
0.8146637696903856 0.9180023682368604

Data HPS3042, Model HPS3084_Fold0  Accuracy: 0.911, AUC: 0.979
Data HPS3042, Model HPS3084_Fold1  Accuracy: 0.926, AUC: 0.987
Data HPS3042, Model HPS3084_Fold2  Accuracy: 0.894, AUC: 0.970
Data HPS3042, Model HPS3084_Fold3  Accuracy: 0.907, AUC: 0.978
Data HPS3042, Model HPS3084_Fold4  A

Data HPS3049, Model HPS3084_Fold1  Accuracy: 0.563, AUC: 0.527
Data HPS3049, Model HPS3084_Fold2  Accuracy: 0.555, AUC: 0.460
Data HPS3049, Model HPS3084_Fold3  Accuracy: 0.570, AUC: 0.481
Data HPS3049, Model HPS3084_Fold4  Accuracy: 0.570, AUC: 0.438
0.5673446083899892 0.4879134613313057

Data HPS3084, Model HPS3042_Fold0  Accuracy: 0.953, AUC: 0.990
Data HPS3084, Model HPS3042_Fold1  Accuracy: 0.959, AUC: 0.990
Data HPS3084, Model HPS3042_Fold2  Accuracy: 0.953, AUC: 0.991
Data HPS3084, Model HPS3042_Fold3  Accuracy: 0.954, AUC: 0.991
Data HPS3084, Model HPS3042_Fold4  Accuracy: 0.918, AUC: 0.975
0.9472995090016367 0.9873241489554415

Data HPS3084, Model HPS3049_Fold0  Accuracy: 0.643, AUC: 0.700
Data HPS3084, Model HPS3049_Fold1  Accuracy: 0.643, AUC: 0.678
Data HPS3084, Model HPS3049_Fold2  Accuracy: 0.610, AUC: 0.627
Data HPS3084, Model HPS3049_Fold3  Accuracy: 0.632, AUC: 0.627
Data HPS3084, Model HPS3049_Fold4  Accuracy: 0.607, AUC: 0.578
0.6271685761047464 0.6419179889527351

D

# 99. Save model

In [24]:
# Re-enable gradients on every parameter before saving.
for param in model.parameters():
    param.requires_grad = True
# `.module` only exists when the model was wrapped in nn.DataParallel
# (ngpu > 1); unwrap conditionally so the single-device case works too.
net = model.module if isinstance(model, nn.DataParallel) else model
torch.save(net.resnet.state_dict(),"Models/Resnet_H3K27ac.pkl")

In [81]:
class ResNet(nn.Module):
    """ResNet-18 truncated after layer2, with a 2-way linear head and softmax output.

    layer3, layer4 and the average pool are replaced by empty ``nn.Sequential``
    modules, and the parameters are restored from the notebook-level ``weight``
    state dict.
    """

    def __init__(self):
        super().__init__()
        # weights=None: the strict load_state_dict below overwrites every
        # parameter and buffer, so loading pretrained weights first would only
        # be discarded work.
        self.resnet = models.resnet18(weights=None)
        self.resnet.layer3 = nn.Sequential()
        self.resnet.layer4 = nn.Sequential()
        self.resnet.avgpool = nn.Sequential()
        # 128 channels x 75 x 75 flattened features.
        # NOTE(review): presumably this corresponds to 600x600 inputs — confirm.
        self.resnet.fc = nn.Linear(128*75*75, 2)
        # `weight` is the state dict loaded earlier in the notebook.
        self.resnet.load_state_dict(weight)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for a batch ``x``."""
        return F.softmax(self.resnet(x), dim=1)


model = ResNet()