In [1]:
import torch
from torch.utils.data.dataset import ConcatDataset
from transformers import BertTokenizer, BertModel
import pandas as pd
from torch.utils.data.sampler import RandomSampler
import math
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
#set the device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# load dbs into memory
data_1 = pd.read_csv("/home/labmol/Documents/Francisco/curated_cytosafe_multiparam/curated_CTG_binary.csv")
data_2 = pd.read_csv("/home/labmol/Documents/Francisco/curated_cytosafe_multiparam/curated_Cell_number_binary.csv")
data_3 = pd.read_csv("/home/labmol/Documents/Francisco/curated_cytosafe_multiparam/curated_Nuclear_area_binary.csv")
data_4 = pd.read_csv("/home/labmol/Documents/Francisco/curated_cytosafe_multiparam/curated_Nuclear_intensity_binary.csv")
data_5 = pd.read_csv("/home/labmol/Documents/Francisco/curated_cytosafe_multiparam/curated_Nuclear_membrane_permeability_binary.csv")

In [4]:
# encode smiles
def tokenize_smiles(smiles_list, labels, tokenizer):

    tokenized_inputs = []

    for smiles, label in zip(smiles_list, labels):
        # Tokenize the SMILES string
        encoded_dict = tokenizer.encode_plus(
            smiles,
            add_special_tokens=True,
            max_length=128,  # Adjust as needed
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        # Append tokenized input IDs and attention mask to the list
        tokenized_inputs.append({
            'input_ids': encoded_dict['input_ids'].flatten(),
            'attention_mask': encoded_dict['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        })

    return tokenized_inputs

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

tokenized_inputs_data_1 = tokenize_smiles(data_1['final_smiles'], data_1['Outcome'], tokenizer)
tokenized_inputs_data_2 = tokenize_smiles(data_2['final_smiles'], data_2['Outcome'], tokenizer)
tokenized_inputs_data_3 = tokenize_smiles(data_3['final_smiles'], data_3['Outcome'], tokenizer)
tokenized_inputs_data_4 = tokenize_smiles(data_4['final_smiles'], data_4['Outcome'], tokenizer)
tokenized_inputs_data_5 = tokenize_smiles(data_5['final_smiles'], data_5['Outcome'], tokenizer)

In [5]:
tokenized_inputs_data_1[0]

{'input_ids': tensor([  101, 27166,  1006,  1039,  1007, 10507, 10085,  1006, 27723,  9468,
          2278,  1006, 18856,  1007, 10507,  2487,  1007, 27723,  9468,  9468,
          2078,  2487,   102,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,   

In [6]:
class CustomDataset():
    def __init__(self, tokenized_inputs, task_idx):
        self.tokenized_inputs = tokenized_inputs
        self.task_idx = task_idx

    def __getitem__(self, idx):
        return {
            'input_ids': self.tokenized_inputs[idx]['input_ids'],
            'attention_mask': self.tokenized_inputs[idx]['attention_mask'],
            'labels': self.tokenized_inputs[idx]['label'],
            'task_idx': self.task_idx
        }

    def __len__(self):
        # change this to return number of samples in your dataset
        return len(self.tokenized_inputs)
    
datasets = [tokenized_inputs_data_1, tokenized_inputs_data_2, tokenized_inputs_data_3, tokenized_inputs_data_4, tokenized_inputs_data_5]
idx_datasets = [0, 1, 2, 3, 4]
concat_dataset = ConcatDataset([CustomDataset(dataset, idx) for dataset, idx in zip(datasets, idx_datasets)])

In [7]:
class BatchSchedulerSampler(torch.utils.data.sampler.Sampler):
    """
    iterate over tasks and provide a random batch per task in each mini-batch
    """
    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.number_of_datasets = len(dataset.datasets)
        self.largest_dataset_size = max([len(cur_dataset.tokenized_inputs) for cur_dataset in dataset.datasets])

    def __len__(self):
        return self.batch_size * math.ceil(self.largest_dataset_size / self.batch_size) * len(self.dataset.datasets)

    def __iter__(self):
        samplers_list = []
        sampler_iterators = []
        for dataset_idx in range(self.number_of_datasets):
            cur_dataset = self.dataset.datasets[dataset_idx]
            sampler = RandomSampler(cur_dataset)
            samplers_list.append(sampler)
            cur_sampler_iterator = sampler.__iter__()
            sampler_iterators.append(cur_sampler_iterator)

        push_index_val = [0] + self.dataset.cumulative_sizes[:-1]
        step = self.batch_size * self.number_of_datasets
        samples_to_grab = self.batch_size
        # for this case we want to get all samples in dataset, this force us to resample from the smaller datasets
        epoch_samples = self.largest_dataset_size * self.number_of_datasets

        final_samples_list = []  # this is a list of indexes from the combined dataset
        for _ in range(0, epoch_samples, step):
            for i in range(self.number_of_datasets):
                cur_batch_sampler = sampler_iterators[i]
                cur_samples = []
                for _ in range(samples_to_grab):
                    try:
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                    except StopIteration:
                        # got to the end of iterator - restart the iterator and continue to get samples
                        # until reaching "epoch_samples"
                        sampler_iterators[i] = samplers_list[i].__iter__()
                        cur_batch_sampler = sampler_iterators[i]
                        cur_sample_org = cur_batch_sampler.__next__()
                        cur_sample = cur_sample_org + push_index_val[i]
                        cur_samples.append(cur_sample)
                final_samples_list.extend(cur_samples)

        return iter(final_samples_list)

In [8]:
batch_size = 32

# dataloader with BatchSchedulerSampler
dataloader = torch.utils.data.DataLoader(dataset=concat_dataset, sampler=BatchSchedulerSampler(dataset=concat_dataset, batch_size=batch_size), batch_size=batch_size, shuffle=False)

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from transformers import BertModel, BertTokenizer, BertConfig  # Adjust imports based on your transformer version

# Define your TaskSpecificMultiTaskModel if not already defined
class TaskSpecificMultiTaskModel(nn.Module):
    def __init__(self, bert_model):
        super(TaskSpecificMultiTaskModel, self).__init__()
        self.bert = bert_model
        self.classifier = nn.Linear(bert_model.config.hidden_size, 1)  # Example classifier layer for binary classification

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        logits = self.classifier(outputs.last_hidden_state[:, 0, :])  # Example: taking the CLS token representation
        return logits

def train(model, dataloader, criterion, optimizer, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    
    # Use tqdm to track progress
    with tqdm(total=len(dataloader), desc='Training') as pbar:
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].float().to(device)  # Assuming label key is 'labels'

            optimizer.zero_grad()

            logits = model(input_ids, attention_mask).squeeze()  # Ensure logits are of shape [batch_size]
            
            # Calculate binary cross-entropy loss
            loss = criterion(logits, labels)
            total_loss += loss.item()

            loss.backward()
            optimizer.step()

            # For accuracy calculation
            predictions = (logits > 0.5).float()  # Adjust threshold as per your task
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

            # Update tqdm progress bar
            pbar.update(1)
            pbar.set_postfix({'Loss': loss.item()})

    average_loss = total_loss / len(dataloader)
    accuracy = total_correct / total_samples

    return average_loss, accuracy

# Example usage
if __name__ == "__main__":
    # Initialize BERT model and tokenizer
    bert_model = BertModel.from_pretrained('bert-base-uncased')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = TaskSpecificMultiTaskModel(bert_model).to(device)

    # Example criterion and optimizer
    criterion = nn.BCEWithLogitsLoss()  # Use BCEWithLogitsLoss for numerical stability
    optimizer = optim.Adam(model.parameters(), lr=1e-5)

    # Assuming you have a DataLoader named dataloader and device defined
    num_epochs = 5
    for epoch in range(num_epochs):
        train_loss, train_accuracy = train(model, dataloader, criterion, optimizer, device)
        print(f"Epoch {epoch + 1}/{num_epochs}:")
        print(f"  Train Total Loss: {train_loss:.4f}, Train Total Accuracy: {train_accuracy:.4f}")



Training: 100%|██████████| 1930/1930 [04:18<00:00,  7.47it/s, Loss=0.496] 


Epoch 1/5:
  Train Total Loss: 0.2905, Train Total Accuracy: 0.8996


Training: 100%|██████████| 1930/1930 [04:20<00:00,  7.41it/s, Loss=0.195] 


Epoch 2/5:
  Train Total Loss: 0.2514, Train Total Accuracy: 0.9040


Training: 100%|██████████| 1930/1930 [04:19<00:00,  7.43it/s, Loss=0.349] 


Epoch 3/5:
  Train Total Loss: 0.2204, Train Total Accuracy: 0.9141


Training: 100%|██████████| 1930/1930 [04:20<00:00,  7.41it/s, Loss=0.159] 


Epoch 4/5:
  Train Total Loss: 0.1989, Train Total Accuracy: 0.9210


Training: 100%|██████████| 1930/1930 [04:20<00:00,  7.41it/s, Loss=0.186] 

Epoch 5/5:
  Train Total Loss: 0.1866, Train Total Accuracy: 0.9261





In [10]:
torch.save(model.state_dict(), 'model_agrvai.pth')