In [1]:
import pandas as pd

# Show long free-text cells (e.g. transcriptions) without clipping them at the default width.
pd.set_option('display.max_colwidth', 500)

# Raw samples live on Google Cloud Storage (pandas needs `gcsfs` installed to read gs:// paths).
# The first CSV column is used as the row index.
DATA_PATH = 'gs://data-healthcare/medical_samples.csv'
df = pd.read_csv(DATA_PATH, index_col=[0])

In [2]:
# Keep only the model input (transcription) and the training target (keywords).
df = df.loc[:, ['transcription', 'keywords']]

In [3]:
import torch

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')

In [4]:
# !pip install transformers
# !pip install tensorboard
# !pip install tensorboardx

In [5]:
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from torch.utils.data import Dataset, DataLoader

In [6]:
# Discard every row that has a missing value in any remaining column.
df = df.dropna(axis='index', how='any')

In [None]:
# Load the pretrained T5 checkpoint; tokenizer and model must come from the same name.
MODEL_NAME = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model = model.to(device)  # nn.Module.to returns the module itself

In [None]:
# Define a custom dataset for training.
# Subclass torch's Dataset explicitly: `class Dataset(Dataset)` shadows the torch name, and
# re-running that cell would make the class inherit from its own previous definition.
class Dataset(torch.utils.data.Dataset):
    """Tokenize (input text, target text) pairs on the fly for T5 seq2seq fine-tuning.

    Each item is a dict of fixed-length 1-D tensors ``input_ids``, ``attention_mask``
    and ``labels`` that can be passed straight to ``T5ForConditionalGeneration(**item)``.

    Args:
        input_texts: indexable sequence of source strings (e.g. ``Series.values``).
        target_queries: indexable sequence of target strings, aligned with ``input_texts``.
        tokenizer: Hugging Face-style tokenizer; its ``pad_token_id`` (if any) is used to
            mask padding in the labels.
        task_prefix: string prepended to every input text.
        max_source_length: pad/truncate length for the inputs (default 512).
        max_target_length: pad/truncate length for the labels (default 512).
    """

    def __init__(self, input_texts, target_queries, tokenizer, task_prefix,
                 max_source_length=512, max_target_length=512):
        self.input_texts = input_texts
        self.target_queries = target_queries
        self.tokenizer = tokenizer
        self.task_prefix = task_prefix
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def __len__(self):
        return len(self.input_texts)

    def _encode(self, text, max_length):
        """Tokenize one string, padded/truncated to exactly ``max_length`` tokens."""
        return self.tokenizer([text], return_tensors="pt", max_length=max_length,
                              truncation=True, padding="max_length")

    def __getitem__(self, index):
        input_text = self.task_prefix + self.input_texts[index]
        target_query = self.target_queries[index]

        input_encoding = self._encode(input_text, self.max_source_length)
        target_encoding = self._encode(target_query, self.max_target_length)

        labels = target_encoding.input_ids.squeeze(0)
        pad_id = getattr(self.tokenizer, 'pad_token_id', None)
        if pad_id is not None:
            # Label positions equal to -100 are ignored by the model's cross-entropy loss.
            # Without this, padding (most of a 512-token keyword target) dominates the loss
            # and the model learns to emit pad tokens.
            labels = labels.masked_fill(labels == pad_id, -100)

        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': labels,
        }

In [None]:
# Load the labeled dataset.
# Draw a reproducible subset: without `random_state` every run trains and validates on
# different rows, so results cannot be compared across runs. `min` avoids a ValueError
# when fewer than N_SAMPLES rows remain after cleaning.
N_SAMPLES = 100
SEED = 42
df1 = df.sample(n=min(N_SAMPLES, len(df)), random_state=SEED)
input_texts = df1.transcription.values  # input texts (the `transcription` column)
target_queries = df1.keywords.values  # corresponding target keyword strings (the `keywords` column)

# Split the dataset into train (80%) and validation (20%) sets with the same seed.
train_input_texts, val_input_texts, train_target_queries, val_target_queries = train_test_split(
    input_texts, target_queries, test_size=0.2, random_state=SEED
)

In [None]:
# Build the train/validation datasets; both share the tokenizer and the task prefix.
# NOTE(review): the prefix says "summary" but the targets are the `keywords` column;
# whatever prefix is used here must also be prepended at inference time.
task_prefix = 'Create a summary for '
dataset_kwargs = {'tokenizer': tokenizer, 'task_prefix': task_prefix}
train_dataset = Dataset(train_input_texts, train_target_queries, **dataset_kwargs)
val_dataset = Dataset(val_input_texts, val_target_queries, **dataset_kwargs)

In [None]:
# Define the training hyperparameters
BATCH_SIZE = 8
NUM_EPOCHS = 3
# 0.01 is far too high for AdamW fine-tuning of a pretrained T5 and tends to destroy the
# pretrained weights; the Hugging Face T5 docs report 1e-4 to 3e-4 working well.
LEARNING_RATE = 1e-4

# Define the optimizer and learning-rate schedule
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# Multiplies the learning rate by 0.1 on every scheduler.step() call.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)
# NOTE: T5ForConditionalGeneration computes its own loss when `labels` are passed
# (the training loop reads `outputs.loss`), so this criterion is not used there;
# it is kept only in case other cells reference it.
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Batch both splits: reshuffle training data every epoch, keep validation order fixed.
train_dataloader, val_dataloader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (val_dataset, False))
)

In [None]:
import tqdm.notebook as tqdm_notebook

# Use the notebook-widget progress bar; `.pandas()` also registers `progress_apply` on pandas objects.
tqdm = tqdm_notebook.tqdm
tqdm.pandas()

In [None]:
# Training loop
# No `writer` is defined in the cells above, so create the TensorBoard writer here.
# If TensorBoard is not installed, skip TensorBoard logging and only print progress.
try:
    from torch.utils.tensorboard import SummaryWriter
    writer = SummaryWriter()
except ImportError:
    writer = None

# Counts optimizer updates so per-batch losses get distinct x-values in TensorBoard.
global_step = 0
for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()
    for batch in tqdm(train_dataloader):
        # Move input_ids / attention_mask / labels onto the model's device.
        batch = {name: tensor.to(device) for name, tensor in batch.items()}

        optimizer.zero_grad()
        outputs = model(**batch)
        loss = outputs.loss  # seq2seq loss computed by the model from `labels`
        loss.backward()
        optimizer.step()

        # Write training loss to TensorBoard, keyed by optimizer step rather than epoch.
        if writer is not None:
            writer.add_scalar('Training Loss', loss.item(), global_step)
        global_step += 1

    # StepLR (step_size=1) counts epochs, so advance it once per epoch.
    scheduler.step()

    # Evaluation on validation set
    model.eval()
    total_val_loss = 0.0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            batch = {name: tensor.to(device) for name, tensor in batch.items()}
            total_val_loss += model(**batch).loss.item()

    # Mean of per-batch losses; guard against an empty validation split.
    avg_val_loss = total_val_loss / max(len(val_dataloader), 1)

    # Write validation loss to TensorBoard
    if writer is not None:
        writer.add_scalar('Validation Loss', avg_val_loss, epoch)

    # Print progress
    print(f'Epoch: {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

# Flush pending events to disk and release the event file.
if writer is not None:
    writer.close()