In [5]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import json
from random import shuffle, seed
import numpy as np
import torch
import torchvision.models as models
from tqdm import tqdm
from torchvision import transforms as trn
from torchnet import meter
from torch.optim.lr_scheduler import StepLR
from captioning.utils.resnet_utils import myResnet
import captioning.utils.resnet as resnet
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss
import copy
import torch
import torch.nn as nn
import torchvision.models.resnet
from torchvision.models.resnet import BasicBlock, Bottleneck

import rdkit
from rdkit import Chem
from rdkit.Chem import  Draw

# Data preparation

In [2]:
from chembl_webresource_client.new_client import new_client

# One query against the ChEMBL molecule endpoint per `max_phase` value 0..4.
# NOTE(review): despite the "approved_drugs" prefix, each name simply holds the
# result of filtering on that max_phase value.
molecule = new_client.molecule
(
    approved_drugs0,
    approved_drugs1,
    approved_drugs2,
    approved_drugs3,
    approved_drugs4,
) = (molecule.filter(max_phase=phase) for phase in range(5))

In [1]:
def _collect_smiles(records):
    """Return the canonical SMILES of every record that carries one.

    Args:
        records: iterable of molecule records; each is indexed as
            ``record['molecule_structures']['canonical_smiles']``.

    Returns:
        list: the SMILES values, in iteration order, of the records for which
        that lookup succeeds.
    """
    collected = []
    for record in tqdm(records):
        try:
            collected.append(record['molecule_structures']['canonical_smiles'])
        except (KeyError, TypeError):
            # No structure for this record (key missing, or the structures
            # entry is not subscriptable, e.g. None): skip it. Anything else,
            # including KeyboardInterrupt, is allowed to propagate.
            continue
    return collected


smiles0 = _collect_smiles(approved_drugs0)
smiles1 = _collect_smiles(approved_drugs1)
smiles2 = _collect_smiles(approved_drugs2)
smiles3 = _collect_smiles(approved_drugs3)
smiles4 = _collect_smiles(approved_drugs4)

In [2]:
def _render_molecules(smiles_list, out_dir):
    """Draw every SMILES string to ``<out_dir>/<index>.png``.

    The file name is the position of the SMILES string in ``smiles_list``, so
    skipped entries leave a gap in the numbering rather than shifting it.

    Args:
        smiles_list: sequence of SMILES strings (``None`` entries are skipped).
        out_dir: directory to write into; created if it does not exist.

    Returns:
        int: number of entries that could not be parsed and were skipped.
    """
    os.makedirs(out_dir, exist_ok=True)
    skipped = 0
    for idx, smile in enumerate(tqdm(smiles_list)):
        # MolFromSmiles yields None for SMILES it cannot parse; drawing None
        # would raise, so such entries are counted and skipped instead.
        mol = None if smile is None else Chem.MolFromSmiles(smile)
        if mol is None:
            skipped += 1
            continue
        Draw.MolToFile(mol, out_dir + '/' + str(idx) + '.png')
    return skipped


for _label, _smiles in enumerate((smiles0, smiles1, smiles2, smiles3, smiles4)):
    _skipped = _render_molecules(_smiles, 'mols/' + str(_label))
    if _skipped:
        print('mols/%d: skipped %d unparsable SMILES' % (_label, _skipped))

In [3]:
# Build the image manifest: one {'image': path, 'label': name} record per file
# found in mols/<label>/ for the labels '0'..'4', then dump it to mols/mols.json.
li = []
for label in ('0', '1', '2', '3', '4'):
    class_dir = 'mols/' + label
    # os.listdir returns entries in arbitrary order; sorting makes the
    # manifest identical from one run to the next.
    for image in tqdm(sorted(os.listdir(class_dir))):
        # A fresh dict per record, so there is no shared dict to deep-copy.
        li.append({'image': class_dir + '/' + image, 'label': label})

with open('mols/mols.json', 'w') as f:
    json.dump(li, f)

# Dataset and data loaders

In [63]:

class molsDataset(Dataset):
    """Image-classification dataset built from a list of image/label records.

    Each record is a mapping with an ``'image'`` entry (path of the image
    file) and a ``'label'`` entry (the class name, e.g. ``'0'``). Only records
    whose label is a key of ``self.label_name`` are kept, and the label is
    translated to the class index stored there, so the targets always lie in
    ``range(len(self.label_name))``.
    """

    def __init__(self, results, transform=None):
        """
        :param results: iterable of ``{'image': path, 'label': name}`` records
        :param transform: optional callable applied to each loaded PIL image
            (e.g. a pipeline that converts it to a tensor)
        """
        # Two-class setup: label '0' -> index 0, label '4' -> index 1.
        # For all five classes use {'0': 0, '1': 1, '2': 2, '3': 3, '4': 4}.
        self.label_name = {'0': 0, '4': 1}
        # (path, class index) pairs; __getitem__ looks samples up here by index.
        self.data_info = self.get_img_info(results, self.label_name)
        self.transform = transform

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(image, class_index)``."""
        path_img, label = self.data_info[index]
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed as soon as the conversion is done.
        with Image.open(path_img) as img_file:
            img = img_file.convert('RGB')     # 0~255

        if self.transform is not None:
            img = self.transform(img)   # preprocessing, e.g. conversion to a tensor

        return img, label

    def __len__(self):
        """Return the number of samples kept in the dataset."""
        return len(self.data_info)

    @staticmethod
    def get_img_info(results, label_name=None):
        """Turn records into a list of ``(image_path, class_index)`` pairs.

        :param results: iterable of ``{'image': path, 'label': name}`` records
        :param label_name: optional mapping from label name to class index.
            When given, records whose label is not a key are dropped and the
            others get the mapped index. When omitted, every record is kept
            and the index is ``int(label)``.
        :return: list of ``(path, class_index)`` tuples in input order
        """
        data_info = list()
        for result in results:
            path_img = result['image']
            label = result['label']
            if label_name is None:
                data_info.append((path_img, int(label)))
            elif str(label) in label_name:
                data_info.append((path_img, label_name[str(label)]))
        return data_info

In [64]:
# Convert each image to a tensor, then normalise per channel.
# NOTE(review): these mean/std values look like the usual ImageNet statistics,
# which would match the ImageNet-initialised backbone - confirm.
preprocess = trn.Compose([
    trn.ToTensor(),
    trn.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

with open('mols/mols.json', 'r') as f:
    datas = json.load(f)

# Hold out part of the records for validation. Building both datasets from
# the same records would make the validation accuracy a training accuracy.
VALID_FRACTION = 0.2
seed(0)          # fixed seed so the split is the same on every run
shuffle(datas)   # the manifest is grouped by label; mix it before slicing
n_valid = max(1, int(len(datas) * VALID_FRACTION))
valid_data = molsDataset(datas[:n_valid], transform=preprocess)
train_data = molsDataset(datas[n_valid:], transform=preprocess)

In [66]:
# The training loader reshuffles the samples at every epoch; the validation
# loader walks the dataset in its stored order.
train_loader = DataLoader(train_data, batch_size=16, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=16, shuffle=False)

# ResNet model

In [51]:
class ResNet(torchvision.models.resnet.ResNet):
    """torchvision ResNet with a different stem max-pool and block strides.

    After the stock constructor runs, this class:

    * replaces ``self.maxpool`` with a 3x3, stride-2 max-pool that uses no
      padding and ``ceil_mode=True``;
    * for the first block of ``layer2``, ``layer3`` and ``layer4``, sets the
      stride of ``conv1`` to ``(2, 2)`` and of ``conv2`` to ``(1, 1)``.

    Neither change alters any parameter shape.

    Args:
        block: residual block class (e.g. ``BasicBlock`` or ``Bottleneck``).
        layers: number of blocks in each of the four stages.
        num_classes: size of the final ``fc`` layer. Defaults to 1000 so the
            class can be built with just ``(block, layers)``, as
            ``resnet101`` does.
    """

    def __init__(self, block, layers, num_classes=1000):
        super(ResNet, self).__init__(block, layers, num_classes)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True) # change
        for i in range(2, 5):
            first_block = getattr(self, 'layer%d' % i)[0]
            first_block.conv1.stride = (2, 2)
            first_block.conv2.stride = (1, 1)

def resnet101(pretrained=False):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
            (the weights are downloaded on first use).

    Returns:
        ResNet: the model, with a 1000-way final layer.
    """
    # Imported here because the notebook's import cell does not provide it.
    from torch.utils import model_zoo

    # NOTE(review): this is the stock torchvision ResNet-101 checkpoint. Its
    # tensor shapes match this model, but the stride placement set up in
    # ResNet.__init__ differs from stock torchvision - confirm that these are
    # the intended weights.
    resnet101_url = 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'

    # num_classes is passed explicitly so the call does not depend on the
    # ResNet constructor having a default; 1000 matches the checkpoint.
    model = ResNet(Bottleneck, [3, 4, 23, 3], num_classes=1000)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(resnet101_url))
    return model



In [56]:
class myResnet(nn.Module):
    """Classifier made of a backbone without its last child plus a new linear head.

    Args:
        resnet: backbone module. All of its direct children except the last
            one are chained into ``self.resnet``.
        num_classes: number of outputs of the new ``fc`` layer.

    The input width of the new head is taken from ``resnet.fc.in_features``
    when the backbone has such an attribute, and falls back to 2048 otherwise.
    """

    def __init__(self, resnet, num_classes):
        super(myResnet, self).__init__()
        # Drop the backbone's last child and keep everything before it.
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])
        # Size the head from the layer it replaces rather than assuming 2048.
        in_features = getattr(getattr(resnet, 'fc', None), 'in_features', 2048)
        self.fc = nn.Linear(in_features=in_features, out_features=num_classes)

    def forward(self, x):
        """Return the head's output for a batch ``x``: ``(batch, num_classes)``."""
        x = self.resnet(x)
        # Collapse everything after the batch dimension into one feature axis.
        x = x.flatten(1)
        x = self.fc(x)
        return x



In [60]:
# Build the ResNet-101 backbone from the imported `resnet` module and
# initialise it from the weights file on disk.
backbone = resnet.resnet101()
weights_path = os.path.join('./data/imagenet_weights', 'resnet101' + '.pth')
# map_location='cpu': load the checkpoint on the CPU whatever device it was
# saved from; the model is moved to the GPU below.
backbone.load_state_dict(torch.load(weights_path, map_location='cpu'))

# Two-way classifier on top of the backbone.
my_resnet = myResnet(backbone, num_classes=2)
my_resnet.cuda()

# Use up to four of the visible GPUs instead of assuming exactly four exist.
device_ids = list(range(min(4, torch.cuda.device_count())))
net = torch.nn.DataParallel(my_resnet, device_ids=device_ids)


# Focal loss

In [7]:
class FocalSmoothedLoss(_WeightedLoss):
    """Focal loss on raw class logits.

    The log-probability of each class is scaled by ``(1 - p) ** gamma`` before
    the negative log-likelihood is taken, so confidently classified samples
    contribute less. With ``gamma=0`` this is ordinary cross-entropy.
    Despite the name, no label smoothing is applied.

    Args:
        gamma: focusing exponent.
        weight: optional 1-D tensor of per-class weights, handed to
            ``F.nll_loss``.
        reduction: ``'mean'``, ``'sum'`` or ``'none'``, handed to
            ``F.nll_loss``.
    """

    def __init__(self, gamma=2, weight=None, reduction='mean'):
        # The base class stores `weight` (as a buffer) and `reduction`.
        super().__init__(weight=weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Compute the loss.

        Args:
            inputs: logits with the classes along dimension 1.
            targets: class indices, as accepted by ``F.nll_loss``.

        Returns:
            Tensor: a scalar for ``'mean'``/``'sum'``, per-sample losses for
            ``'none'``.
        """
        logpt = F.log_softmax(inputs, dim=1)
        pt = torch.exp(logpt)
        focal_logpt = (1 - pt) ** self.gamma * logpt
        # nll_loss applies the class weights and the requested reduction, so
        # neither is applied a second time here.
        return F.nll_loss(focal_logpt, targets, weight=self.weight,
                          reduction=self.reduction)

# optimizer

In [67]:
# Loss: plain cross-entropy on the logits (FocalSmoothedLoss() is the alternative).
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the wrapped model.
optimizer = torch.optim.Adam(
    net.parameters(),
    lr=5e-5,
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
)

# Multiply the learning rate by 0.1 after every 10 scheduler steps.
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Validation

In [68]:
def val():
    """Evaluate the module-level ``net`` on ``valid_loader``.

    Uses the module-level ``criterion`` for the loss. The model is switched to
    eval mode for the pass and put back into whatever mode it was in before,
    so calling this in the middle of training does not leave the model in
    eval mode.

    Returns:
        tuple: ``(accuracy, mean_loss)`` where accuracy is correct / total
        samples and mean_loss is the average of the per-batch losses. Both
        are NaN when the loader yields no batches.
    """
    was_training = net.training
    net.eval()
    total_val = 0
    correct_val = 0
    loss_val = 0.0
    try:
        with torch.no_grad():
            t = tqdm(valid_loader)
            for data in t:
                inputs, labels = data[0].cuda(), data[1].cuda()
                outputs = net(inputs)
                loss = criterion(outputs, labels)
                # Index of the largest logit per sample is the predicted class.
                _, predicted = torch.max(outputs, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()
                loss_val += loss.item()
                t.set_description('acc: %s  Loss: %s' % (str(correct_val / total_val), loss.item()))
    finally:
        # Restore the previous train/eval mode even if evaluation failed.
        net.train(was_training)

    n_batches = len(valid_loader)
    acc = correct_val / total_val if total_val else float('nan')
    mean_loss = loss_val / n_batches if n_batches else float('nan')
    print('acc: %s  Loss: %s' % (str(acc), mean_loss))
    return acc, mean_loss

# Train

In [4]:
NUM_EPOCHS = 50

# Checkpoints are written under log/; create it up front so the first save
# does not fail after a full epoch of training.
os.makedirs('log', exist_ok=True)

for epoch in range(NUM_EPOCHS):
    # Re-enter training mode at the start of every epoch: validation at the
    # end of the previous epoch puts the model into eval mode.
    net.train()
    t = tqdm(train_loader)
    for data in t:
        inputs, labels = data[0].cuda(), data[1].cuda()
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        t.set_description('epoch: %s  Loss: %s' % (str(epoch), loss.item()))
    scheduler.step()
    # Loss of the last batch of the epoch, not an epoch average.
    print('epoch: %s  Loss: %s' % (str(epoch), loss.item()))

    acc, mean_loss = val()
    state = {'net': net.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': str(epoch)}
    torch.save(state, "log/mean_loss:" + str(mean_loss) + 'acc:' + str(acc) + 'epoch:' + str(epoch) + '.pth')