In [1]:
import pandas as pd
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments

In [11]:
# Load data
# 60 train, 20 dev, 20 test
train_df = pd.read_csv('data/processed/phee/ace/train_w_test_tag_new_mapped.csv')
dev_df = pd.read_csv('data/processed/phee/ace/dev_w_test_tag_new_mapped.csv')
test_df = pd.read_csv('data/processed/phee/ace/test_w_test_tag_new_mapped.csv')


# train_df = train_df[['Sentence', 'Med_Tag']]
# train_df.rename(columns={"Sentence": "sentence", "Med_Tag": "tag"}, inplace=True)
# dev_df = dev_df[['Sentence', 'Med_Tag']]
# dev_df.rename(columns={"Sentence": "sentence", "Med_Tag": "tag"}, inplace=True)
# test_df = test_df[['Sentence', 'Med_Tag']]
# test_df.rename(columns={"Sentence": "sentence", "Med_Tag": "tag"}, inplace=True)

train_df['sentence'] = train_df['sentence'].apply(lambda x: x.split())
train_df['tag'] = train_df['tag'].apply(lambda x: x.split())
dev_df['sentence'] = dev_df['sentence'].apply(lambda x: x.split())
dev_df['tag'] = dev_df['tag'].apply(lambda x: x.split())
test_df['sentence'] = test_df['sentence'].apply(lambda x: x.split())
test_df['tag'] = test_df['tag'].apply(lambda x: x.split())

train_dev_df = pd.concat([train_df, dev_df], ignore_index=True)

# save to csv
train_dev_df.to_csv('data/processed/phee/ace/final_train.csv', index=False) # final_train is train+dev
test_df.to_csv('data/processed/phee/ace/final_dev.csv', index=False) # final_dev is test
test_df.to_csv('data/processed/phee/ace/final_test.csv', index=False) # final_test is test

In [12]:
train_dev_df

Unnamed: 0,sentence,tag
0,"[objective, :, to, test, the, hypothesis, that...","[O, O, O, O, O, O, O, O, O, O, O, O, O, O, O, ..."
1,"[an, evaluation, of, ovarian, structure, and, ...","[O, I-Test, O, O, O, O, O, O, O, O, O, I-Backg..."
2,"[phenobarbital, hepatotoxicity, in, an, 8, -, ...","[I-Treatment, I-Problem, O, O, I-Background, O..."
3,"[the, authors, report, a, case, of, balint, sy...","[O, O, O, O, O, O, I-Problem, I-Problem, O, I-..."
4,"[according, to, the, naranjo, probability, sca...","[O, O, O, O, I-Test, I-Test, O, I-Treatment, O..."
...,...,...
3706,"[successful, challenge, with, clozapine, in, a...","[O, O, O, I-Treatment, O, O, O, O, I-Problem, O]"
3707,"[case, summary, :, a, 57, -, year, -, old, fem...","[O, O, O, O, I-Background, O, I-Background, O,..."
3708,"[acute, intravascular, hemolysis, developed, w...","[I-Problem, I-Problem, I-Problem, O, O, O, I-P..."
3709,"[intravitreal, triamcinolone, may, have, had, ...","[I-Treatment, I-Treatment, O, O, O, O, O, O, O..."


In [31]:
label2id = {'O': 0, 'I-Treatment': 1, 'I-Test': 2, 'I-Problem': 3, 'I-Background': 4, 'I-Other': 5}

In [32]:

# Define a custom dataset
class NERDataset(Dataset):
    def __init__(self, sentences, tags, tokenizer, max_len):
        self.sentences = sentences
        self.tags = tags
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        sentence = self.sentences[idx]
        tags = self.tags[idx]
        labels = [label2id[tag] for tag in tags]

        encoding = self.tokenizer.encode_plus(
            sentence,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )
        labels = torch.tensor(labels, dtype=torch.long)
        labels = labels[:self.max_len]
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': labels
        }

# Define the model
class BertLSTM(nn.Module):
    def __init__(self, bert_model_name, lstm_hidden_size, num_tags):
        super(BertLSTM, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.lstm = nn.LSTM(input_size=768, hidden_size=lstm_hidden_size, batch_first=True, dropout=0.1, bidirectional=True)
        self.classifier = nn.Linear(lstm_hidden_size, num_tags)

    def forward(self, input_ids, attention_mask):
        embeddings = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        lstm_out, _ = self.lstm(embeddings)
        logits = self.classifier(lstm_out)
        return logits

# Preparing the tokenizer and model
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
model = BertLSTM(bert_model_name, lstm_hidden_size=50, num_tags=6)

# Prepare the data
max_len = 128  # You might need to adjust this
train_dataset = NERDataset(train_df['sentence'], train_df['tag'], tokenizer, max_len)

# Training arguments
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=1,
    per_device_train_batch_size=2,
    logging_dir='./logs',
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
)

# Train the model
trainer.train()


  result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,


RuntimeError: Caught RuntimeError in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/home/ozan/.conda/envs/medh/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/home/ozan/.conda/envs/medh/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/tmp/ipykernel_3988/2252036063.py", line 47, in forward
    logits = self.classifier(lstm_out)
  File "/home/ozan/.conda/envs/medh/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ozan/.conda/envs/medh/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x100 and 50x6)
