<a href="https://colab.research.google.com/github/itskutush/Attention-based-automated-radiology-report-generation/blob/main/Transformer%20NMT%20with%20Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import pandas as pd

In [2]:
df = pd.read_csv("/content/translated_output.csv")

In [3]:
df = df.dropna().reset_index(drop=True)
df

Unnamed: 0,findings,findings_translated
0,The cardiac silhouette and mediastinum size ar...,कार्डियक सिल्हूट और मीडियास्टिनम का आकार सामान...
1,Borderline cardiomegaly. Midline sternotomy XX...,बॉर्डरलाइन कार्डियोमेगाली.मिडलाइन स्टर्नोटॉमी ...
2,There are diffuse bilateral interstitial and a...,क्रोनिक ऑब्सट्रक्टिव फेफड़े की बीमारी और बुलस ...
3,The cardiomediastinal silhouette and pulmonary...,कार्डियोमीडियास्टिनल सिल्हूट और फुफ्फुसीय वाहि...
4,Heart size and mediastinal contour are within ...,हृदय का आकार और मीडियास्टीनल रूपरेखा सामान्य स...
...,...,...
3332,The heart is mildly enlarged. Left hemidiaphra...,हृदय हल्का सा बड़ा हो गया है।बायां हेमिडियाफ्र...
3333,Similar mild cardiomegaly. Of the pulmonary va...,समान हल्का कार्डियोमेगाली।फुफ्फुसीय संवहनीकरण ...
3334,The cardiomediastinal silhouette and pulmonary...,कार्डियोमीडियास्टिनल सिल्हूट और फुफ्फुसीय वाहि...
3335,The lungs are clear. Heart size is normal. No ...,फेफड़े साफ हैं.हृदय का आकार सामान्य है.कोई न्य...


In [39]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
import json
from collections import Counter

# Optional BLEU score evaluation
try:
    import nltk
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
    nltk.download('punkt')
    # Download punkt_tab resource for sentence_bleu
    nltk.download('punkt_tab')
    HAVE_NLTK = True
except Exception:
    HAVE_NLTK = False

# -------------------------
# 1) Hyperparameters
# -------------------------
CSV_PATH = "./translated_output.csv"  # Adjust path as needed
BATCH_SIZE = 32
EMBEDDING_DIM = 256
NUM_HEADS = 8
NUM_LAYERS = 4
DFF = 512  # Feed-forward dimension
DROPOUT = 0.1
EPOCHS = 50
MAX_VOCAB_SIZE = None  # None = keep all words; set int to limit
SAVE_DIR = "./transformer_weights"
os.makedirs(SAVE_DIR, exist_ok=True)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -------------------------
# 2) Preprocess utility
# -------------------------
def preprocess_text(s: str) -> str:
    if not isinstance(s, str):
        s = ""
    return s.strip().lower()

# -------------------------
# 3) Load and prepare data
# -------------------------
def load_data(csv_path):
    df = pd.read_csv(csv_path)
    df = df.dropna(subset=['findings', 'findings_translated']).reset_index(drop=True)
    df['eng'] = df['findings'].astype(str).apply(preprocess_text)
    df['hin'] = df['findings_translated'].astype(str).apply(lambda x: f"<start> {x.strip()} <end>")
    return df['eng'].tolist(), df['hin'].tolist(), df

eng_texts, hin_texts, df = load_data(CSV_PATH)

# -------------------------
# 4) Tokenizers
# -------------------------
class Tokenizer:
    def __init__(self, texts, max_vocab_size=None):
        self.word2idx = {"<pad>": 0, "<unk>": 1}
        self.idx2word = {0: "<pad>", 1: "<unk>"}
        word_counts = Counter()
        for text in texts:
            words = text.split()
            word_counts.update(words)

        # Build vocabulary
        vocab = [word for word, _ in word_counts.most_common(max_vocab_size)]
        for i, word in enumerate(vocab, start=2):
            self.word2idx[word] = i
            self.idx2word[i] = word

    def encode(self, text):
        return [self.word2idx.get(word, self.word2idx["<unk>"]) for word in text.split()]

    def decode(self, indices):
        return " ".join(self.idx2word.get(idx, "<unk>") for idx in indices if idx != 0)

    def vocab_size(self):
        return len(self.word2idx)

eng_tokenizer = Tokenizer(eng_texts, MAX_VOCAB_SIZE)
hin_tokenizer = Tokenizer(hin_texts, MAX_VOCAB_SIZE)

if "<start>" not in hin_tokenizer.word2idx or "<end>" not in hin_tokenizer.word2idx:
    raise ValueError("Hindi tokenizer missing <start> or <end> tokens")

# -------------------------
# 5) Dataset and DataLoader
# -------------------------
class TranslationDataset(Dataset):
    def __init__(self, eng_texts, hin_texts, eng_tokenizer, hin_tokenizer):
        self.eng_texts = eng_texts
        self.hin_texts = hin_texts
        self.eng_tokenizer = eng_tokenizer
        self.hin_tokenizer = hin_tokenizer

    def __len__(self):
        return len(self.eng_texts)

    def __getitem__(self, idx):
        eng = self.eng_tokenizer.encode(self.eng_texts[idx])
        hin = self.hin_tokenizer.encode(self.hin_texts[idx])
        return torch.tensor(eng, dtype=torch.long), torch.tensor(hin, dtype=torch.long)

def collate_fn(batch):
    eng_batch, hin_batch = zip(*batch)
    eng_batch = pad_sequence(eng_batch, batch_first=True, padding_value=0)
    hin_batch = pad_sequence(hin_batch, batch_first=True, padding_value=0)
    hin_input = hin_batch[:, :-1]  # Decoder input (without <end>)
    hin_target = hin_batch[:, 1:]  # Decoder target (without <start>)
    return eng_batch, hin_input, hin_target

dataset = TranslationDataset(eng_texts, hin_texts, eng_tokenizer, hin_tokenizer)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
steps_per_epoch = len(dataset) // BATCH_SIZE

# -------------------------
# 6) Transformer Model
# -------------------------
class TransformerNMT(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, dff, max_seq_len, dropout=0.1):
        super(TransformerNMT, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoding = self.positional_encoding(max_seq_len, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dff,
            dropout=dropout,
            batch_first=True
        )
        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

    def positional_encoding(self, max_len, d_model):
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        return pe.unsqueeze(0).to(DEVICE)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        src = self.src_embedding(src) * np.sqrt(self.d_model)
        src = src + self.pos_encoding[:, :src.size(1), :].to(DEVICE)
        src = self.dropout(src)

        tgt = self.tgt_embedding(tgt) * np.sqrt(self.d_model)
        tgt = tgt + self.pos_encoding[:, :tgt.size(1), :].to(DEVICE)
        tgt = self.dropout(tgt)

        output = self.transformer(
            src, tgt,
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask
        )
        output = self.fc(output)
        return output

# -------------------------
# 7) Masking
# -------------------------
def create_padding_mask(seq):
    return (seq == 0).to(DEVICE)

def create_look_ahead_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask.to(DEVICE)

# -------------------------
# 8) Instantiate model
# -------------------------
max_len = max(max(len(eng_tokenizer.encode(t)) for t in eng_texts),
              max(len(hin_tokenizer.encode(t)) for t in hin_texts))
model = TransformerNMT(
    src_vocab_size=eng_tokenizer.vocab_size(),
    tgt_vocab_size=hin_tokenizer.vocab_size(),
    d_model=EMBEDDING_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    dff=DFF,
    max_seq_len=max_len,
    dropout=DROPOUT
).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding

# -------------------------
# 9) Training loop
# -------------------------
def train_step(model, src, tgt_input, tgt_target):
    model.train()
    optimizer.zero_grad()

    src_padding_mask = create_padding_mask(src)
    tgt_padding_mask = create_padding_mask(tgt_input)
    tgt_mask = create_look_ahead_mask(tgt_input.size(1))

    output = model(src, tgt_input,
                   tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask,
                   tgt_padding_mask=tgt_padding_mask)

    loss = criterion(output.view(-1, output.size(-1)), tgt_target.view(-1))
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(EPOCHS):
    total_loss = 0.0
    prog = tqdm(enumerate(dataloader), total=steps_per_epoch, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for i, (src, tgt_input, tgt_target) in prog:
        if i >= steps_per_epoch:
            break
        src, tgt_input, tgt_target = src.to(DEVICE), tgt_input.to(DEVICE), tgt_target.to(DEVICE)
        loss = train_step(model, src, tgt_input, tgt_target)
        total_loss += loss
        if (i + 1) % 50 == 0:
            prog.set_postfix({'batch_loss': loss})
    epoch_loss = total_loss / steps_per_epoch
    print(f"\nEpoch {epoch+1} Loss {epoch_loss:.4f}")
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"transformer_epoch{epoch+1}.pth"))

# -------------------------
# 10) Inference
# -------------------------
def evaluate(sentence, max_len_out=max_len):
    model.eval()
    sentence = preprocess_text(sentence)
    src = torch.tensor([eng_tokenizer.encode(sentence)], dtype=torch.long).to(DEVICE)
    src = pad_sequence([src[0]], batch_first=True, padding_value=0).to(DEVICE)

    tgt = torch.tensor([[hin_tokenizer.word2idx["<start>"]]], dtype=torch.long).to(DEVICE)
    output_tokens = []

    with torch.no_grad():
        for _ in range(max_len_out):
            src_padding_mask = create_padding_mask(src)
            tgt_padding_mask = create_padding_mask(tgt)
            tgt_mask = create_look_ahead_mask(tgt.size(1))

            output = model(src, tgt,
                          tgt_mask=tgt_mask,
                          src_padding_mask=src_padding_mask,
                          tgt_padding_mask=tgt_padding_mask)

            predicted_id = torch.argmax(output[:, -1, :], dim=-1).item()
            if predicted_id == hin_tokenizer.word2idx["<end>"]:
                break
            output_tokens.append(predicted_id)
            tgt = torch.cat([tgt, torch.tensor([[predicted_id]], dtype=torch.long).to(DEVICE)], dim=1)

    return hin_tokenizer.decode(output_tokens)

def translate(sentence):
    return evaluate(sentence)

# -------------------------
# 11) Test translations
# -------------------------
examples = [
    "Borderline cardiomegaly. Midline sternotomy noted.",
    "The cardiac silhouette and mediastinum size are within normal limits."
]
for ex in examples:
    print("ENG:", ex)
    print("HIN:", translate(ex))
    print()

# -------------------------
# 12) BLEU evaluation
# -------------------------
if HAVE_NLTK:
    smooth = SmoothingFunction().method1
    n_samples = min(50, len(df))
    refs = []
    hyps = []
    for i in range(n_samples):
        cand = translate(df['eng'].iloc[i])
        ref = df['hin'].iloc[i].replace('<start>', '').replace('<end>', '').strip()
        ref_tok = nltk.word_tokenize(ref)
        cand_tok = nltk.word_tokenize(cand)
        refs.append([ref_tok])
        hyps.append(cand_tok)
    bleu_scores = [sentence_bleu(refs[i], hyps[i], smoothing_function=smooth) for i in range(len(hyps))]
    print("Avg sentence BLEU (sample):", np.mean(bleu_scores))
else:
    print("nltk not available. Skip BLEU.")

# -------------------------
# 13) Save tokenizers
# -------------------------
with open(os.path.join(SAVE_DIR, "eng_tokenizer.json"), "w", encoding='utf-8') as f:
    json.dump({"word2idx": eng_tokenizer.word2idx, "idx2word": eng_tokenizer.idx2word}, f)
with open(os.path.join(SAVE_DIR, "hin_tokenizer.json"), "w", encoding='utf-8') as f:
    json.dump({"word2idx": hin_tokenizer.word2idx, "idx2word": hin_tokenizer.idx2word}, f)

print("Training complete.")

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
Epoch 1/50: 100%|██████████| 104/104 [00:06<00:00, 15.98it/s, batch_loss=5.7]



Epoch 1 Loss 6.6586


Epoch 2/50: 100%|██████████| 104/104 [00:06<00:00, 15.75it/s, batch_loss=4.67]



Epoch 2 Loss 4.9714


Epoch 3/50: 100%|██████████| 104/104 [00:06<00:00, 16.28it/s, batch_loss=4.21]



Epoch 3 Loss 4.2926


Epoch 4/50: 100%|██████████| 104/104 [00:06<00:00, 15.77it/s, batch_loss=3.81]



Epoch 4 Loss 3.8422


Epoch 5/50: 100%|██████████| 104/104 [00:06<00:00, 15.57it/s, batch_loss=2.94]



Epoch 5 Loss 3.5247


Epoch 6/50: 100%|██████████| 104/104 [00:06<00:00, 15.55it/s, batch_loss=3.84]



Epoch 6 Loss 3.2712


Epoch 7/50: 100%|██████████| 104/104 [00:06<00:00, 15.66it/s, batch_loss=3.14]



Epoch 7 Loss 3.0676


Epoch 8/50: 100%|██████████| 104/104 [00:06<00:00, 16.20it/s, batch_loss=3.13]



Epoch 8 Loss 2.8996


Epoch 9/50: 100%|██████████| 104/104 [00:06<00:00, 16.08it/s, batch_loss=2.93]



Epoch 9 Loss 2.7567


Epoch 10/50: 100%|██████████| 104/104 [00:06<00:00, 16.43it/s, batch_loss=2.49]



Epoch 10 Loss 2.6325


Epoch 11/50: 100%|██████████| 104/104 [00:06<00:00, 16.17it/s, batch_loss=2.54]



Epoch 11 Loss 2.5258


Epoch 12/50: 100%|██████████| 104/104 [00:06<00:00, 16.40it/s, batch_loss=2.21]



Epoch 12 Loss 2.4324


Epoch 13/50: 100%|██████████| 104/104 [00:06<00:00, 16.14it/s, batch_loss=2.55]



Epoch 13 Loss 2.3419


Epoch 14/50: 100%|██████████| 104/104 [00:06<00:00, 16.30it/s, batch_loss=2.41]



Epoch 14 Loss 2.2697


Epoch 15/50: 100%|██████████| 104/104 [00:06<00:00, 15.96it/s, batch_loss=2.13]



Epoch 15 Loss 2.1923


Epoch 16/50: 100%|██████████| 104/104 [00:06<00:00, 16.32it/s, batch_loss=2.45]



Epoch 16 Loss 2.1269


Epoch 17/50: 100%|██████████| 104/104 [00:06<00:00, 16.02it/s, batch_loss=2.12]



Epoch 17 Loss 2.0662


Epoch 18/50: 100%|██████████| 104/104 [00:06<00:00, 16.33it/s, batch_loss=1.9]



Epoch 18 Loss 2.0078


Epoch 19/50: 100%|██████████| 104/104 [00:06<00:00, 15.98it/s, batch_loss=1.9]



Epoch 19 Loss 1.9584


Epoch 20/50: 100%|██████████| 104/104 [00:06<00:00, 16.37it/s, batch_loss=1.64]



Epoch 20 Loss 1.9006


Epoch 21/50: 100%|██████████| 104/104 [00:06<00:00, 16.18it/s, batch_loss=1.83]



Epoch 21 Loss 1.8515


Epoch 22/50: 100%|██████████| 104/104 [00:06<00:00, 16.10it/s, batch_loss=1.96]



Epoch 22 Loss 1.8061


Epoch 23/50: 100%|██████████| 104/104 [00:06<00:00, 16.23it/s, batch_loss=1.71]



Epoch 23 Loss 1.7650


Epoch 24/50: 100%|██████████| 104/104 [00:06<00:00, 15.89it/s, batch_loss=1.75]



Epoch 24 Loss 1.7199


Epoch 25/50: 100%|██████████| 104/104 [00:06<00:00, 16.33it/s, batch_loss=2.12]



Epoch 25 Loss 1.6786


Epoch 26/50: 100%|██████████| 104/104 [00:06<00:00, 15.80it/s, batch_loss=1.63]



Epoch 26 Loss 1.6352


Epoch 27/50: 100%|██████████| 104/104 [00:06<00:00, 16.10it/s, batch_loss=1.94]



Epoch 27 Loss 1.5950


Epoch 28/50: 100%|██████████| 104/104 [00:06<00:00, 15.83it/s, batch_loss=1.62]



Epoch 28 Loss 1.5596


Epoch 29/50: 100%|██████████| 104/104 [00:06<00:00, 16.31it/s, batch_loss=1.62]



Epoch 29 Loss 1.5258


Epoch 30/50: 100%|██████████| 104/104 [00:06<00:00, 15.84it/s, batch_loss=1.39]



Epoch 30 Loss 1.4911


Epoch 31/50: 100%|██████████| 104/104 [00:06<00:00, 16.62it/s, batch_loss=1.69]



Epoch 31 Loss 1.4533


Epoch 32/50: 100%|██████████| 104/104 [00:06<00:00, 16.07it/s, batch_loss=1.59]



Epoch 32 Loss 1.4246


Epoch 33/50: 100%|██████████| 104/104 [00:06<00:00, 16.21it/s, batch_loss=1.45]



Epoch 33 Loss 1.3881


Epoch 34/50: 100%|██████████| 104/104 [00:06<00:00, 16.07it/s, batch_loss=0.978]



Epoch 34 Loss 1.3559


Epoch 35/50: 100%|██████████| 104/104 [00:06<00:00, 16.26it/s, batch_loss=1.58]



Epoch 35 Loss 1.3255


Epoch 36/50: 100%|██████████| 104/104 [00:06<00:00, 15.98it/s, batch_loss=1.56]



Epoch 36 Loss 1.2984


Epoch 37/50: 100%|██████████| 104/104 [00:06<00:00, 16.19it/s, batch_loss=1.11]



Epoch 37 Loss 1.2698


Epoch 38/50: 100%|██████████| 104/104 [00:06<00:00, 16.37it/s, batch_loss=1.31]



Epoch 38 Loss 1.2362


Epoch 39/50: 100%|██████████| 104/104 [00:06<00:00, 15.94it/s, batch_loss=1.24]



Epoch 39 Loss 1.2112


Epoch 40/50: 100%|██████████| 104/104 [00:06<00:00, 16.44it/s, batch_loss=0.951]



Epoch 40 Loss 1.1784


Epoch 41/50: 100%|██████████| 104/104 [00:06<00:00, 16.05it/s, batch_loss=0.956]



Epoch 41 Loss 1.1507


Epoch 42/50: 100%|██████████| 104/104 [00:06<00:00, 16.06it/s, batch_loss=0.707]



Epoch 42 Loss 1.1286


Epoch 43/50: 100%|██████████| 104/104 [00:06<00:00, 15.93it/s, batch_loss=0.845]



Epoch 43 Loss 1.0938


Epoch 44/50: 100%|██████████| 104/104 [00:06<00:00, 16.24it/s, batch_loss=0.892]



Epoch 44 Loss 1.0777


Epoch 45/50: 100%|██████████| 104/104 [00:06<00:00, 16.09it/s, batch_loss=1.3]



Epoch 45 Loss 1.0488


Epoch 46/50: 100%|██████████| 104/104 [00:06<00:00, 16.45it/s, batch_loss=0.921]



Epoch 46 Loss 1.0210


Epoch 47/50: 100%|██████████| 104/104 [00:06<00:00, 16.14it/s, batch_loss=1.12]



Epoch 47 Loss 0.9957


Epoch 48/50: 100%|██████████| 104/104 [00:06<00:00, 16.19it/s, batch_loss=0.81]



Epoch 48 Loss 0.9714


Epoch 49/50: 100%|██████████| 104/104 [00:06<00:00, 15.99it/s, batch_loss=0.789]



Epoch 49 Loss 0.9480


Epoch 50/50: 100%|██████████| 104/104 [00:06<00:00, 16.13it/s, batch_loss=0.992]



Epoch 50 Loss 0.9276
ENG: Borderline cardiomegaly. Midline sternotomy noted.
HIN: स्टर्नोटॉमी नोट किया गया।सिवनी स्टर्नोटॉमी के बाद कार्डियोमेगाली नोट किया गया है।दोनों फेफड़ों की स्थिति।स्थिर स्टर्नोटॉमी नोट किया गया।

ENG: The cardiac silhouette and mediastinum size are within normal limits.
HIN: हृदय सिल्हूट और मीडियास्टिनम सामान्य सीमा के भीतर हैं।कोई तीव्र हड्डी संबंधी असामान्यताएं नहीं हैं।

Avg sentence BLEU (sample): 0.6437904150165957
Training complete.
