In [None]:
!pip install kaggle



In [None]:
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json

In [None]:
# Download the Indiana University chest X-ray dataset (a ~13 GB zip) into the working directory.
!kaggle datasets download -d raddar/chest-xrays-indiana-university

Downloading chest-xrays-indiana-university.zip to /content
100% 13.2G/13.2G [07:40<00:00, 31.4MB/s]
100% 13.2G/13.2G [07:40<00:00, 30.7MB/s]


In [None]:
!unzip /content/chest-xrays-indiana-university.zip

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images/images_normalized/219_IM-0799-1001.dcm.png  
  inflating: images/images_normalized/219_IM-0799-2001.dcm.png  
  inflating: images/images_normalized/21_IM-0729-1001-0001.dcm.png  
  inflating: images/images_normalized/21_IM-0729-1001-0002.dcm.png  
  inflating: images/images_normalized/2200_IM-0811-1001.dcm.png  
  inflating: images/images_normalized/2200_IM-0811-2001.dcm.png  
  inflating: images/images_normalized/2201_IM-0811-1002.dcm.png  
  inflating: images/images_normalized/2202_IM-0811-1001.dcm.png  
  inflating: images/images_normalized/2202_IM-0811-1002.dcm.png  
  inflating: images/images_normalized/2203_IM-0812-1001.dcm.png  
  inflating: images/images_normalized/2203_IM-0812-2001.dcm.png  
  inflating: images/images_normalized/2204_IM-0813-1001.dcm.png  
  inflating: images/images_normalized/2204_IM-0813-1002.dcm.png  
  inflating: images/images_normalized/2205_IM-0814-1001.dcm.png  
  infla

In [None]:
!git clone https://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@github.com/falco-tigris/CS577_DeepLearning_Project.git

Cloning into 'CS577_DeepLearning_Project'...
remote: Enumerating objects: 66, done.[K
remote: Counting objects: 100% (66/66), done.[K
remote: Compressing objects: 100% (45/45), done.[K
remote: Total 66 (delta 22), reused 55 (delta 15), pack-reused 0[K
Receiving objects: 100% (66/66), 457.27 KiB | 9.94 MiB/s, done.
Resolving deltas: 100% (22/22), done.


In [None]:
# Copy the cloned project into /content so its modules can be imported from the working
# directory (presumably the utils / models / data / tokenizer packages imported below).
!cp -r /content/CS577_DeepLearning_Project/* /content/

In [None]:
# Place the dataset CSVs and images under /content/data, where the project's data code
# presumably looks for them (TODO confirm against data/preprocessing.py).
# NOTE(review): assumes /content/data already exists (copied from the repo above); otherwise
# the first cp would create a *file* named /content/data.
!cp /content/indiana_projections.csv /content/data
!cp /content/indiana_reports.csv /content/data
!cp -r /content/images /content/data

# Pretrained CNN + Transformer decoder

In [None]:
import math
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
from torch.nn import functional as F
from torchtext.data.metrics import bleu_score
from torch.utils.data.dataloader import DataLoader

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class ImageEncoderReportDecoderConfig:
    """Plain container for the hyper-parameters of ImageEncoderReportDecoder.

    ImageEncoderReportDecoder itself only reads ``vocab_size``, ``block_size`` and
    ``n_embd``; the remaining flags are stored for the caller's bookkeeping.
    """

    def __init__(self, vocab_size, block_size, n_embd, pretrain, train_decoder, pretrained_encoder_model):
        # Model dimensions.
        self.vocab_size, self.block_size, self.n_embd = vocab_size, block_size, n_embd
        # Training-mode flags and an optional handle to a pretrained encoder.
        self.pretrain, self.train_decoder = pretrain, train_decoder
        self.pretrained_encoder_model = pretrained_encoder_model

class ImageEncoderReportDecoder(nn.Module):
    """Image-to-report model: frozen image encoder + Transformer encoder + causal decoder.

    The image encoder ``img_enc`` runs under ``torch.no_grad()``. Its output is reshaped
    into a sequence of feature tokens, each projected to ``n_embd`` by ``img_enc_linear``.
    The tokens get a learned positional embedding and pass through a one-layer Transformer
    encoder. A two-layer causal Transformer decoder then predicts report tokens while
    cross-attending to those representations.

    Args:
        config: object with ``vocab_size``, ``block_size`` and ``n_embd``. ``n_embd`` must
            be divisible by the 33 attention heads used below.
        img_enc: image feature extractor (kept frozen; see ``representation``).
        img_enc_out_shape: only index 1 is used, as the input width of ``img_enc_linear``.
        img_enc_name: selects the input/output reshaping in ``representation``
            ("ResNet18", "UNet", "ResNetAE"; any other name passes images through as-is).
    """

    def __init__(self, config, img_enc, img_enc_out_shape, img_enc_name="ResNet18"):
        super().__init__()

        self.cnf = config
        self.img_enc_name = img_enc_name

        self.img_enc = img_enc
        self.img_enc_linear = nn.Linear(img_enc_out_shape[1], config.n_embd)

        # One learned positional table, shared by the image tokens and the report tokens.
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        self.drop = nn.Dropout(0.3)

        encoder_layer = nn.TransformerEncoderLayer(d_model=config.n_embd, nhead=33, dim_feedforward=2048, dropout=0.3, activation='gelu', batch_first=True, norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=1)

        # Not used by forward/decode in this class; kept so saved state_dicts still match.
        self.contrastive_head = nn.Linear(config.n_embd, 1)

        self.tgt_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.tgt_drop = nn.Dropout(0.3)

        decoder_layer = nn.TransformerDecoderLayer(d_model=config.n_embd, nhead=33, dim_feedforward=2048, dropout=0.3, activation='gelu', batch_first=True, norm_first=True)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=2)

        self.dcd_proj = nn.Linear(config.n_embd, config.vocab_size)

        self.block_size = config.block_size

        # Initialise only the layers this class creates (same order as self.apply would use).
        # Re-initialising img_enc would overwrite the Linear/Embedding/LayerNorm weights of
        # a pretrained encoder.
        for child_name, child in self.named_children():
            if child_name != "img_enc":
                child.apply(self._init_weights)

        # Causal mask for nn.TransformerDecoder: True above the diagonal means that position
        # may NOT attend to later positions.
        self.register_buffer("mask", torch.ones(config.block_size, config.block_size).triu(diagonal=1).bool())

    def get_block_size(self):
        """Return the maximum sequence length covered by ``pos_emb`` and the causal mask."""
        return self.block_size

    def _init_weights(self, module):
        """GPT-style init: N(0, 0.02) for Linear/Embedding weights, zero biases, unit LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def forward(self, x, targets, len_mask):
        """Encode images ``x`` and decode ``targets``; returns ``(logits, loss, prediction)``."""
        reps = self.representation(x)
        return self.decode(reps, targets, len_mask)

    def configure_optimizers(self, train_config):
        """Build AdamW over all parameters from train_config's lr, betas and weight_decay."""
        return torch.optim.AdamW(
            self.parameters(),
            lr=train_config.learning_rate,
            betas=train_config.betas,
            # Previously ignored, which silently used AdamW's default of 0.01.
            weight_decay=getattr(train_config, "weight_decay", 0.01),
        )

    def representation(self, x):
        """Encode a batch of images into a ``(b, t, n_embd)`` sequence of representations.

        ``t`` is the number of feature tokens the image encoder yields after the reshaping
        below (a flat ``(b, d)`` feature vector becomes ``d`` tokens of width 1).
        """
        if self.img_enc_name in ("ResNet18", "UNet"):
            # Add a channel axis and repeat it three times for encoders that expect RGB.
            x = torch.cat((x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1)), dim=1)
        elif self.img_enc_name == "ResNetAE":
            x = x.unsqueeze(1)

        # The image encoder is frozen: no gradients are computed for it.
        with torch.no_grad():
            x = self.img_enc(x)
            if self.img_enc_name == "ResNetAE":
                b, t, e1, e2 = x['z'].shape
                x = x['z'].reshape((b, t, -1))

        # Normalise to (batch, tokens, features).
        if len(x.shape) < 3:
            x = x.unsqueeze(2)
        elif len(x.shape) > 3:
            x = x.squeeze(1)

        x = self.img_enc_linear(x)
        _, t, _ = x.size()
        x = x + self.pos_emb[:, :t, :]
        x = self.drop(x)
        return self.transformer_encoder(x)

    def decode(self, reps, targets, len_mask=None):
        """Run the causal decoder over ``targets`` (token ids) given encoder output ``reps``.

        With ``len_mask`` (training/evaluation): teacher forcing on ``targets[:, :-1]`` to
        predict ``targets[:, 1:]``. Positions where ``len_mask > 0`` form the key padding
        mask. The loss is the per-token cross-entropy (``reduction='none'``; padding
        positions are not excluded).
        Without ``len_mask`` (generation): the whole of ``targets`` is decoded and the loss is 0.

        Returns:
            ``(logits, loss, prediction)``, where ``prediction = argmax(logits, -1)``.
        """
        tgt_emb = self.tgt_emb(targets)
        _, tt, _ = tgt_emb.size()
        tgt_emb = tgt_emb + self.pos_emb[:, :tt, :]
        tgt_emb = self.tgt_drop(tgt_emb)

        if len_mask is not None:
            # NOTE(review): True in tgt_key_padding_mask means "ignore this key"; this assumes
            # len_mask is > 0 exactly at padding positions. Confirm against the dataset.
            len_mask = len_mask > 0
            # Slice the mask to the actual length so sequences shorter than block_size work too.
            dec_out = self.transformer_decoder(tgt_emb[:, :-1, :], reps, tgt_mask=self.mask[:tt - 1, :tt - 1], tgt_key_padding_mask=len_mask[:, :-1])
        else:
            # Use the same causal mask as in training. Without it, earlier positions attend to
            # later tokens, and with two layers the last position's logits no longer match
            # what the model learned.
            dec_out = self.transformer_decoder(tgt_emb, reps, tgt_mask=self.mask[:tt, :tt])

        logits = self.dcd_proj(dec_out)
        loss = 0
        if len_mask is not None:
            loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), targets[:, 1:].reshape(-1), reduction='none')

        prediction = torch.argmax(logits, dim=-1)
        return logits, loss, prediction

class TrainerConfig:
    """Bag of training hyper-parameters for Trainer.

    Class attributes are the defaults. Every keyword passed to the constructor is stored on
    the instance, overriding the matching default or adding a new attribute.
    """

    # --- optimisation ---
    max_epochs = 10
    batch_size = 64
    learning_rate = 3e-4
    betas = (0.9, 0.95)
    grad_norm_clip = 1.0
    weight_decay = 0.1
    # --- learning-rate schedule (used when lr_decay is True): linear warmup over
    # warmup_tokens, then cosine decay towards final_tokens with the multiplier floored at 0.1.
    # Both token counts are GPT-3 values and may not suit other setups.
    lr_decay = False
    warmup_tokens = 375e6
    final_tokens = 260e9

    # --- checkpointing, data loading and evaluation ---
    ckpt_path = None
    num_workers = 0   # DataLoader worker processes
    pretrain = False  # when True, evaluation reports only the loss (no BLEU)
    tokenizer = None

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

class Trainer:
    """Epoch-based training loop with optional evaluation, BLEU-2 scoring and checkpointing.

    Args:
        model: module whose ``forward(x, y, len_mask)`` returns ``(logits, loss, prediction)``
            and which provides ``configure_optimizers(config)``.
        train_dataset: dataset yielding ``(image, target_ids, len_mask)`` triples.
        test_dataset: same kind of dataset for evaluation, or None to skip evaluation.
        config: a TrainerConfig. ``config.tokenizer`` (if set) supplies the EOS token used
            to truncate sequences before BLEU.
        word_2_id / id_2_word: vocabulary mappings (id -> word is used to build BLEU tokens).
    """

    def __init__(self, model, train_dataset, test_dataset, config, word_2_id, id_2_word):
        self.model = model
        self.config = config
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset
        self.id_2_word = id_2_word
        self.word_2_id = word_2_id
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokens = 0

    def save_checkpoint(self):
        """Write the model's state_dict to ``config.ckpt_path``."""
        # DataParallel wrappers keep the raw model object in the .module attribute.
        raw_model = self.model.module if hasattr(self.model, "module") else self.model
        torch.save(raw_model.state_dict(), self.config.ckpt_path)

    def _eos_id(self):
        """Return the EOS token id from ``config.tokenizer``, or None when unavailable.

        The original code read ``self.tokenizer``, which Trainer never sets. The resulting
        AttributeError was swallowed by a bare except, so EOS truncation never happened.
        """
        tokenizer = getattr(self.config, "tokenizer", None)
        eos = getattr(tokenizer, "eos_token", None)
        if isinstance(eos, str):
            # NOTE(review): eos_token may be the word rather than its id; map it. TODO confirm.
            eos = self.word_2_id.get(eos)
        return eos

    @staticmethod
    def _truncate_at_eos(ids, eos_id):
        """Return ``ids`` up to (not including) the first ``eos_id``; the whole list if absent."""
        if eos_id is not None and eos_id in ids:
            return ids[:ids.index(eos_id)]
        return list(ids)

    def _ids_to_words(self, ids):
        """Map token ids to their word strings."""
        return [str(self.id_2_word[i]) for i in ids]

    def _corpus_bleu2(self, tgts, preds):
        """Corpus BLEU-2 of predicted vs. target id batches, both truncated at EOS."""
        eos_id = self._eos_id()
        tgt_rows = torch.vstack(tgts).cpu().tolist()
        pred_rows = torch.vstack(preds).cpu().tolist()
        assert len(tgt_rows) == len(pred_rows)
        # bleu_score expects, per candidate, a list of reference token lists.
        references = [[self._ids_to_words(self._truncate_at_eos(row, eos_id))] for row in tgt_rows]
        candidates = [self._ids_to_words(self._truncate_at_eos(row, eos_id)) for row in pred_rows]
        return bleu_score(candidates, references, max_n=4, weights=[0.5, 0.5, 0, 0])

    def _update_lr(self, optimizer, y):
        """Count this batch's tokens, apply warmup + cosine decay, and return the new lr."""
        config = self.config
        # Tokens processed this step (labels >= 0, i.e. not an ignore value such as -100).
        self.tokens += int((y >= 0).sum().item())
        if self.tokens < config.warmup_tokens:
            # Linear warmup.
            lr_mult = float(self.tokens) / float(max(1, config.warmup_tokens))
        else:
            # Cosine decay, never below 10% of the base learning rate.
            progress = float(self.tokens - config.warmup_tokens) / float(max(1, config.final_tokens - config.warmup_tokens))
            lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
        lr = config.learning_rate * lr_mult
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def train(self):
        """Run ``config.max_epochs`` epochs, evaluating and checkpointing after each one."""
        model, config = self.model, self.config
        raw_model = model.module if hasattr(self.model, "module") else model
        optimizer = raw_model.configure_optimizers(config)

        def run_epoch(split, epoch):
            """One pass over the split; returns (mean loss, BLEU-2 or 0 for test / None for train)."""
            is_train = split == 'train'
            model.train(is_train)
            data = self.train_dataset if is_train else self.test_dataset
            loader = DataLoader(data, shuffle=True, pin_memory=True,
                                batch_size=config.batch_size,
                                num_workers=config.num_workers)
            # Targets/predictions are only needed for BLEU on the evaluation split.
            collect_for_bleu = not is_train and not config.pretrain

            losses, tgts, preds = [], [], []
            pbar = tqdm(enumerate(loader), total=len(loader)) if is_train else enumerate(loader)
            for it, (x, y, len_masks) in pbar:
                x = x.to(self.device)
                y = y.to(self.device)
                len_masks = len_masks.to(self.device)

                with torch.set_grad_enabled(is_train):
                    _, loss, pred = model(x, y, len_masks)
                    loss = loss.mean()
                    losses.append(loss.item())

                if collect_for_bleu:
                    # The model predicts y[:, 1:] (decode shifts targets by one), so the
                    # reference is aligned with that rather than including the start token.
                    tgts.append(y[:, 1:])
                    preds.append(pred)

                if not is_train:
                    continue

                model.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip)
                optimizer.step()

                lr = self._update_lr(optimizer, y) if config.lr_decay else config.learning_rate
                pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}. lr {lr:e}")

            mean_loss = float(np.mean(losses))
            if is_train:
                return mean_loss, None

            test_bleu = 0
            if collect_for_bleu and tgts:
                test_bleu = self._corpus_bleu2(tgts, preds)
                print("bleu_score_2: ", test_bleu)
            print("test loss: ", mean_loss)
            return mean_loss, test_bleu

        best_loss = float('inf')
        best_bleu = float('-inf')
        self.tokens = 0

        for epoch in range(config.max_epochs):
            train_loss, _ = run_epoch('train', epoch)

            if self.test_dataset is None:
                # Without an evaluation split, every epoch is checkpointed.
                print("Train: ", train_loss)
                good_model = True
            else:
                test_loss, test_bleu = run_epoch('test', epoch)
                print("Train: ", train_loss, "Test: ", test_loss)
                # Save when the loss is within 5% of the best so far, or BLEU improved.
                good_model = test_loss < 1.05 * best_loss or test_bleu > best_bleu

            if config.ckpt_path is not None and good_model:
                if self.test_dataset is not None:
                    best_loss = min(best_loss, test_loss)
                    best_bleu = max(best_bleu, test_bleu)
                self.save_checkpoint()


## Data

In [None]:
import torch
import numpy as np
import pandas as pd
from utils import set_seed
import models.training as tr
import data.preprocessing as pr
from torchvision import transforms
from matplotlib import pyplot as plt
from tokenizer.tokenizer import Tokenizer
from collections import OrderedDict

#########################
set_seed(33)

# Image encoder: torchvision ResNet-18 with its classification head replaced by Identity,
# so it returns the 512-d pooled feature vector of each image.
# NOTE(review): this section is titled "Pretrained CNN", but the original used
# pretrained=False. ImageEncoderReportDecoder.representation runs the encoder under
# torch.no_grad(), so a randomly initialised ResNet would never be trained. Load the
# ImageNet weights instead; loading a saved model.pth via load_state_dict replaces them anyway.
img_enc_resnet = torch.hub.load('pytorch/vision:v0.8.0', 'resnet18', pretrained=True)

# To use DenseNet-121 instead: load 'densenet121' above and uncomment the lines below.
# state_dict = torch.load('/content/m-30012020-104001.pth.tar')
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
#     if 'denseblock' in k:
#         param = k.split(".")
#         k = ".".join(param[:-3] + [param[-3]+param[-2]] + [param[-1]])
#         new_state_dict[k] = v
# img_enc_resnet.load_state_dict(new_state_dict,strict=False)
# img_enc_resnet.classifier = torch.nn.Linear(1024, 512)
img_enc_resnet.fc = torch.nn.Identity()

img_enc_resnet.input_shape = (224, 224)
# 512 features, fed to the decoder model as 512 tokens of width 1.
img_enc_resnet.output_shape = (512, 1)

# NOTE(review): any name other than "ResNet18"/"UNet"/"ResNetAE" makes the model feed images
# to the encoder unchanged (no channel replication). Presumably the dataset already yields
# 3-channel tensors; TODO confirm.
img_enc_name = "rnet"
img_enc = img_enc_resnet
img_enc_width, img_enc_height = img_enc.input_shape
img_enc_out_shape = img_enc.output_shape
# Also the length of the positional table, which must cover the 512 image tokens.
block_size = 512
##########################
reports = pr.load_reports()
tokenizer = Tokenizer(reports, 'word')

# Load data
batch_size = 8
uids = np.unique(pr.projections.index)

# Image preprocessing: PIL image -> float tensor, resized to the encoder's input size.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(img_enc.input_shape, antialias=False)
])

train_data, val_data, test_data = pr.create_dataloaders(uids, pr.IMAGES_PATH, max_length=block_size, batch_size=batch_size, transform=transform, tokenizer=tokenizer)


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.8.0


In [None]:
# Create model
vocab_size = tokenizer.get_vocab_size()
n_embd = 528
# The encoder/decoder layers use nhead=33; multi-head attention needs n_embd % nhead == 0.
assert n_embd % 33 == 0, "n_embd must be divisible by the 33 attention heads"
config = ImageEncoderReportDecoderConfig(vocab_size, block_size, n_embd, True, True, None)
model = ImageEncoderReportDecoder(config, img_enc, img_enc_out_shape, img_enc_name)
model.to(device)

# Train model
epochs = 5
tokens_per_epoch = len(train_data) * block_size

train_config = TrainerConfig(max_epochs=epochs, batch_size=batch_size, learning_rate=1.0e-3,
                             betas=(0.9, 0.95), weight_decay=0, lr_decay=True,
                             warmup_tokens=tokens_per_epoch,
                             final_tokens=epochs * tokens_per_epoch,
                             ckpt_path='reportnet',
                             num_workers=8,
                             pretrain=True,
                             tokenizer=tokenizer)  # lets BLEU evaluation truncate at EOS

# Training is disabled; uncomment to retrain.
# NOTE(review): Trainer checkpoints to ckpt_path ('reportnet'), while the cells below load
# 'model.pth', which the manual save writes from the final epoch, not the best one.
# trainer = Trainer(model, train_data, val_data, train_config, tokenizer.word2idx, tokenizer.idx2word)
# trainer.train()

# # Save model
# torch.save(model.state_dict(), 'model.pth')



## Predict

In [None]:
import torch

# Rebuild the architecture with the training-time hyper-parameters (n_embd=528) and load
# the saved weights.
vocab_size = tokenizer.get_vocab_size()
config = ImageEncoderReportDecoderConfig(vocab_size, block_size, 528, True, True, None)
model = ImageEncoderReportDecoder(config, img_enc, img_enc_out_shape, img_enc_name)
# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime, and vice versa.
model.load_state_dict(torch.load('model.pth', map_location=device))

<All keys matched successfully>

In [None]:
@torch.no_grad()
def sample(model, x, y0, steps):
    """Greedily generate ``steps`` report tokens for one image.

    Args:
        model: an ImageEncoderReportDecoder (anything exposing ``representation`` and
            ``decode``); it is switched to eval mode.
        x: image batch for ``model.representation``. Only a batch of one is supported,
            because each generated token is read back with ``.item()``.
        y0: ``(1, t)`` tensor of seed token ids (e.g. the start token).
        steps: number of tokens to generate. ``t + steps - 1`` must not exceed the
            model's block_size, the length of its positional table.

    Returns:
        list[int]: the generated token ids, excluding the seed.
    """
    model.eval()
    generated = []

    # The image encoding does not depend on the tokens, so compute it once.
    reps = model.representation(x)

    for _step in range(steps):
        logits, _, _ = model.decode(reps, y0, None)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (batch, 1)
        generated.append(next_token.item())
        y0 = torch.cat((y0, next_token), dim=1)

    return generated

In [None]:
# Greedy-decode every validation sample and score the generated reports with corpus BLEU-1..4.
eval_device = "cuda" if torch.cuda.is_available() else "cpu"
eval_model = model.to(eval_device)

# NOTE(review): sequences are compared as token ids; if eos_token is the EOS word rather
# than its id, map it through word2idx. TODO confirm against Tokenizer.
eos_id = tokenizer.eos_token
if isinstance(eos_id, str):
    eos_id = tokenizer.word2idx.get(eos_id, eos_id)

references = []  # per sample: a list holding one reference token list
candidates = []  # per sample: the generated token list

print("len val dataset:", len(val_data))
for x, y, _ in tqdm(val_data):
    x = x.unsqueeze(0).to(eval_device)
    y = y.unsqueeze(0).to(eval_device)  # (1, T)
    gen = sample(eval_model, x, y[:, :1], steps=30)

    # y[0] only seeds generation (the model never predicts it), so it is left out of the
    # reference. Both sequences are cut at EOS; the whole sequence is kept when there is none.
    tgt_ids = y[0, 1:].tolist()
    tgt_ids = tgt_ids[:tgt_ids.index(eos_id)] if eos_id in tgt_ids else tgt_ids
    gen_ids = gen[:gen.index(eos_id)] if eos_id in gen else gen

    references.append([[str(tokenizer.idx2word[i]) for i in tgt_ids]])
    candidates.append([str(tokenizer.idx2word[i]) for i in gen_ids])

test_bleu_1 = bleu_score(candidates, references, max_n=4, weights=[1, 0, 0, 0])
test_bleu_2 = bleu_score(candidates, references, max_n=4, weights=[0.5, 0.5, 0, 0])
test_bleu_3 = bleu_score(candidates, references, max_n=4, weights=[0.34, 0.33, 0.33, 0])
test_bleu_4 = bleu_score(candidates, references, max_n=4, weights=[0.25, 0.25, 0.25, 0.25])
avg_bleu = (test_bleu_1 + test_bleu_2 + test_bleu_3 + test_bleu_4) / 4
print(f"\nBleu Scores:\nB_1:{test_bleu_1} \nB_2:{test_bleu_2} \nB_3:{test_bleu_3} \nB_4:{test_bleu_4}\n----Avg:{avg_bleu} ----")

len val dataset: 1495


1495it [04:59,  5.00it/s]



Bleu Scores:
B_1:0.17196721770476692 
B_2:0.10390777126579594 
B_3:0.0741463480342335 
B_4:0.056423741096470394
----Avg:0.10161126952531668 ----


In [None]:
# Greedy-decode every test sample and score the generated reports with corpus BLEU-1..4.
eval_device = "cuda" if torch.cuda.is_available() else "cpu"
eval_model = model.to(eval_device)

# NOTE(review): sequences are compared as token ids; if eos_token is the EOS word rather
# than its id, map it through word2idx. TODO confirm against Tokenizer.
eos_id = tokenizer.eos_token
if isinstance(eos_id, str):
    eos_id = tokenizer.word2idx.get(eos_id, eos_id)

references = []  # per sample: a list holding one reference token list
candidates = []  # per sample: the generated token list

print("len test dataset:", len(test_data))
for x, y, _ in tqdm(test_data):
    x = x.unsqueeze(0).to(eval_device)
    y = y.unsqueeze(0).to(eval_device)  # (1, T)
    gen = sample(eval_model, x, y[:, :1], steps=30)

    # y[0] only seeds generation (the model never predicts it), so it is left out of the
    # reference. Both sequences are cut at EOS; the whole sequence is kept when there is none.
    tgt_ids = y[0, 1:].tolist()
    tgt_ids = tgt_ids[:tgt_ids.index(eos_id)] if eos_id in tgt_ids else tgt_ids
    gen_ids = gen[:gen.index(eos_id)] if eos_id in gen else gen

    references.append([[str(tokenizer.idx2word[i]) for i in tgt_ids]])
    candidates.append([str(tokenizer.idx2word[i]) for i in gen_ids])

test_bleu_1 = bleu_score(candidates, references, max_n=4, weights=[1, 0, 0, 0])
test_bleu_2 = bleu_score(candidates, references, max_n=4, weights=[0.5, 0.5, 0, 0])
test_bleu_3 = bleu_score(candidates, references, max_n=4, weights=[0.34, 0.33, 0.33, 0])
test_bleu_4 = bleu_score(candidates, references, max_n=4, weights=[0.25, 0.25, 0.25, 0.25])
avg_bleu = (test_bleu_1 + test_bleu_2 + test_bleu_3 + test_bleu_4) / 4
print(f"\nBleu Scores:\nB_1:{test_bleu_1} \nB_2:{test_bleu_2} \nB_3:{test_bleu_3} \nB_4:{test_bleu_4}\n----Avg:{avg_bleu} ----")

len val dataset: 1493


1493it [04:59,  4.98it/s]



Bleu Scores:
B_1:0.16962956302209464 
B_2:0.10233284128569753 
B_3:0.0730797575523815 
B_4:0.05593645047671087
----Avg:0.10024465308422113 ----


In [None]:
!pip install gradio

Collecting gradio
  Downloading gradio-4.7.1-py3-none-any.whl (16.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m68.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl (15 kB)
Collecting fastapi (from gradio)
  Downloading fastapi-0.104.1-py3-none-any.whl (92 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.9/92.9 kB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting ffmpy (from gradio)
  Downloading ffmpy-0.3.1.tar.gz (5.5 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting gradio-client==0.7.0 (from gradio)
  Downloading gradio_client-0.7.0-py3-none-any.whl (302 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m302.7/302.7 kB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting httpx (from gradio)
  Downloading httpx-0.25.2-py3-none-any.whl (74 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
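# If gradio was just installed, restart the runtime first, then re-run every cell above: the demo below needs the model classes, encoder settings and tokenizer they define.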
import gradio as gr

In [None]:
from PIL import Image
import requests

In [None]:
import gradio as gr
import torch
from PIL import Image
import torchvision.transforms as transforms
# Import other necessary libraries and modules

# Rebuild the tokenizer first, the same way as in the Data section, so the model's
# vocabulary size comes from it.
from tokenizer.tokenizer import Tokenizer
import data.preprocessing as pr

reports = pr.load_reports()
tokenizer = Tokenizer(reports, 'word')
vocab_size = tokenizer.get_vocab_size()

# Load the trained model. map_location lets a GPU-saved checkpoint load on any runtime.
demo_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config = ImageEncoderReportDecoderConfig(vocab_size, block_size, n_embd, True, True, None)
model = ImageEncoderReportDecoder(config, img_enc, img_enc_out_shape, img_enc_name)
model.load_state_dict(torch.load('model.pth', map_location=demo_device))
model = model.to(demo_device)




In [None]:
@torch.no_grad()
def sample(model, x, y0, steps):
    """Greedily generate ``steps`` report tokens for one image.

    Args:
        model: an ImageEncoderReportDecoder (anything exposing ``representation`` and
            ``decode``); it is switched to eval mode.
        x: image batch for ``model.representation``. Only a batch of one is supported,
            because each generated token is read back with ``.item()``.
        y0: ``(1, t)`` tensor of seed token ids (e.g. the start token).
        steps: number of tokens to generate.

    Returns:
        list[int]: the generated token ids, excluding the seed.
    """
    model.eval()
    generated = []

    # The image encoding does not depend on the tokens, so compute it once.
    reps = model.representation(x)

    for _step in range(steps):
        logits, _, _ = model.decode(reps, y0, None)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (batch, 1)
        generated.append(next_token.item())
        y0 = torch.cat((y0, next_token), dim=1)

    return generated

# Function to generate text report
def generate_text_report(image, model, tokenizer, steps=30):
    """Generate a free-text report for a single uploaded chest X-ray.

    Args:
        image: PIL image from the Gradio upload widget, or None if nothing was uploaded.
        model: trained ImageEncoderReportDecoder.
        tokenizer: the word-level Tokenizer the model was trained with.
        steps: maximum number of tokens to generate.

    Returns:
        str: the generated words joined by spaces ("" when no image was given).
    """
    if image is None:
        return ""

    # Same preprocessing as the training transform: ToTensor, then Resize without antialiasing.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224), antialias=False)
    ])
    # The encoder receives images unchanged, so force 3 channels (uploads may be L or RGBA).
    image = transform(image.convert("RGB")).unsqueeze(0)  # add batch dimension

    # Run wherever the model's weights live instead of assuming CUDA.
    device = next(model.parameters()).device
    initial_token = torch.tensor([[tokenizer.word2idx["<SOS>"]]], device=device)

    gen = sample(model, image.to(device), initial_token, steps)

    # Stop at the end-of-sequence token if the tokenizer defines one.
    # NOTE(review): eos_token may be the word rather than its id; map it. TODO confirm.
    eos = getattr(tokenizer, "eos_token", None)
    if isinstance(eos, str):
        eos = tokenizer.word2idx.get(eos)
    if eos in gen:
        gen = gen[:gen.index(eos)]

    report = [tokenizer.idx2word[idx] for idx in gen if idx in tokenizer.idx2word]
    return ' '.join(report)

# Gradio interface
def gradio_interface(image):
    """Gradio callback: turn the uploaded image into a generated report string."""
    report_text = generate_text_report(image, model, tokenizer)
    return report_text

# Gradio UI: one image input, handed to the callback as a PIL image, and a plain-text output.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type='pil'),
    outputs="text",
    title="CS577 - Deep Learning Project \n Chest X-Ray Report Generation",
    description="Upload a chest X-ray image to get a report."
)

In [None]:
# In Colab, launch() turns on a public *.gradio.live share link, as the log below shows;
# anyone with that URL can use the demo while the cell runs. debug=True keeps the cell
# running to surface errors; interrupt the cell to stop the server.
iface.launch(debug=True)

Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
Running on public URL: https://dfc291fbe3f9838b87.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


<PIL.Image.Image image mode=RGB size=2048x2575 at 0x7B856C3F69E0>
Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://dfc291fbe3f9838b87.gradio.live


