In [1]:
import pandas as pd
pd.set_option('display.max_colwidth', 500)

df = pd.read_csv('data.csv', index_col=[0])

In [2]:
df = df[['transcription', 'keywords']]

In [3]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
# !pip install transformers
# !pip install tensorboard
# !pip install tensorboardx

In [5]:
from sklearn.model_selection import train_test_split
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch
from torch.utils.data import Dataset, DataLoader

In [6]:
df = df.dropna()

In [7]:
# Load the pretrained T5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)

In [24]:
# Define a custom dataset for training
class Dataset(Dataset):
    def __init__(self, input_texts, target_queries, tokenizer, task_prefix):
        self.input_texts = input_texts
        self.target_queries = target_queries
        self.tokenizer = tokenizer
        self.task_prefix = task_prefix

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        
        input_text = self.task_prefix + self.input_texts[index]
        target_query = self.target_queries[index]

        input_encoding = self.tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
        target_encoding = self.tokenizer(target_query, return_tensors="pt", max_length=512, truncation=True, padding="max_length")
        
        return {
            'input_ids': input_encoding.input_ids.squeeze(0),
            'attention_mask': input_encoding.attention_mask.squeeze(0),
            'labels': target_encoding.input_ids.squeeze(0),
        }

In [25]:
# Load the labeled dataset
df1 = df.sample(500)
input_texts = df1.transcription.values # List of input texts
target_queries = df1.keywords.values  # List of corresponding target SQL queries

# Split the dataset into train and validation sets
train_input_texts, val_input_texts, train_target_queries, val_target_queries = train_test_split(input_texts, target_queries, test_size=0.2, random_state=42)

In [26]:
# Create instances of the custom dataset
task_prefix = 'Create a summary for '
train_dataset = Dataset(train_input_texts, train_target_queries, tokenizer, task_prefix)
val_dataset = Dataset(val_input_texts, val_target_queries, tokenizer, task_prefix)

In [27]:
# Define the training hyperparameters
BATCH_SIZE = 1
NUM_EPOCHS = 3
LEARNING_RATE = 0.001

# Define the optimier and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [28]:
# Create data loaders
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [29]:
from tqdm.notebook import tqdm
tqdm.pandas()

In [30]:
# Training loop
for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()
    for batch in tqdm(train_dataloader):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        optimizer.zero_grad()
        outputs = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=labels.to(device))
        loss = outputs.loss
        loss.backward()
        optimizer.step()

    # Evaluation on validation set
    model.eval()
    total_val_loss = 0
    for batch in tqdm(val_dataloader):
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['labels']

        with torch.no_grad():
            outputs = model(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device), labels=labels.to(device))
            val_loss = outputs.loss
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / len(val_dataloader)
    
    # Print progress
    print(f'Epoch: {epoch+1}, Validation Loss: {avg_val_loss:.4f}')

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Validation Loss 0.2441992333624512 0
Epoch: 1, Validation Loss: 0.2442


  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Validation Loss 0.2392995275557041 1
Epoch: 2, Validation Loss: 0.2393


  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Validation Loss 0.2375073814485222 2
Epoch: 3, Validation Loss: 0.2375


In [31]:
input_text = task_prefix + df.transcription.iloc[4]
input_encoding = tokenizer([input_text], return_tensors="pt", max_length=512, truncation=True, padding="max_length")

input_ids = input_encoding['input_ids']
attention_mask = input_encoding['attention_mask']

with torch.no_grad():
    outputs = model.generate(input_ids=input_ids.to(device), attention_mask=attention_mask.to(device))
    print(input_text)
    print(df.keywords.iloc[4])
    print(tokenizer.decode(outputs[0]))

Create a summary for 1.  The left ventricular cavity size and wall thickness appear normal.  The wall motion and left ventricular systolic function appears hyperdynamic with estimated ejection fraction of 70% to 75%.  There is near-cavity obliteration seen.  There also appears to be increased left ventricular outflow tract gradient at the mid cavity level consistent with hyperdynamic left ventricular systolic function.  There is abnormal left ventricular relaxation pattern seen as well as elevated left atrial pressures seen by Doppler examination.,2.  The left atrium appears mildly dilated.,3.  The right atrium and right ventricle appear normal.,4.  The aortic root appears normal.,5.  The aortic valve appears calcified with mild aortic valve stenosis, calculated aortic valve area is 1.3 cm square with a maximum instantaneous gradient of 34 and a mean gradient of 19 mm.,6.  There is mitral annular calcification extending to leaflets and supportive structures with thickening of mitral va

