In [9]:
import sys
import os

# Put the parent of the current working directory on the import path
# (presumably the repository root that holds the `models` and `utils`
# packages imported below -- confirm against the repo layout).
# The membership check keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
_parent_dir = os.path.abspath("..")
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [10]:
%load_ext autoreload
%autoreload 2

import warnings
import torch
from models.ctgenerate import CTGENERATE
from utils.CTGenerateInference import CTGenerateInference
from utils.ctvit import CTViT
from utils.maskgit import MaskGit
from utils.t5 import T5Encoder

# Notebook-wide runtime settings (the three calls are independent of each other).
# Autograd anomaly detection is switched off explicitly.
torch.autograd.set_detect_anomaly(mode=False)
# Tensor printing uses torch's "default" print profile.
torch.set_printoptions(profile="default")
# Every Python warning is suppressed from this point on.
warnings.simplefilter(action="ignore")

# --- Model assembly ---------------------------------------------------------
# Hyper-parameters are gathered in plain dicts and unpacked into the
# constructors, so the configuration can be read (or printed) separately from
# the objects built from it.

ctvit_kwargs = dict(
    dim=512,
    codebook_size=8192,
    image_size=128,
    patch_size=16,
    temporal_patch_size=2,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=32,
    heads=8,
    model_type="ctgenerate",
)
ctvit = CTViT(**ctvit_kwargs)

# NOTE(review): num_tokens equals the CTViT codebook_size and dim equals the
# CTViT dim above; presumably these are meant to stay in sync -- confirm.
maskgit_kwargs = dict(
    num_tokens=8192,
    max_seq_len=10000,
    dim=512,
    dim_context=768,
    depth=6,
)
maskgit = MaskGit(**maskgit_kwargs)

# Text encoder built with its constructor defaults.
t5 = T5Encoder()

# Combine the three components into the full model.
ctgenerate = CTGENERATE(ctvit=ctvit, maskgit=maskgit, t5=t5)

# Load the pretrained weights.
# NOTE(review): strict=False presumably tolerates missing/unexpected keys in
# the checkpoint -- confirm in CTGENERATE.load that mismatches are reported.
checkpoint_path = "/project/project_465001111/ct_clip/pretrained_models/ctgenerate_filtered.pt"
ctgenerate.load(checkpoint_path, strict=False)

# --- Inference configuration ------------------------------------------------
# The storage roots are defined once and every input/output location is
# derived from them, instead of repeating the same absolute prefix in each
# argument. The defaults reproduce the original paths exactly; set the
# environment variables to run the notebook against a different location.
# (`os` is imported in the first cell of the notebook.)
PROJECT_ROOT = os.environ.get("CT_CLIP_PROJECT_ROOT", "/project/project_465001111/ct_clip")
SCRATCH_ROOT = os.environ.get("CT_CLIP_SCRATCH_ROOT", "/scratch/project_465001111/ct_clip")
REPO_ROOT = f"{PROJECT_ROOT}/CT-CLIP-UT"

# Run settings a reader is likely to tune.
BATCH_SIZE = 1
NUM_WORKERS = 4
NUM_VALID_SAMPLES = 10

# Paths are passed as plain strings, as in the original call.
inference = CTGenerateInference(
    ctgenerate,
    valid_reports = f"{REPO_ROOT}/reports/valid_reports.csv",
    data_valid = f"{SCRATCH_ROOT}/data_volumes/dataset/valid",
    valid_labels = f"{REPO_ROOT}/labels/valid_labels.csv",
    valid_metadata = f"{REPO_ROOT}/metadata/valid_metadata.csv",
    results_folder = f"{REPO_ROOT}/src/results/valid/ctgenerate",
    batch_size = BATCH_SIZE,
    num_workers = NUM_WORKERS,
    num_valid_samples = NUM_VALID_SAMPLES
)

# Run inference over the configured validation samples.
inference.infer()


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Validation size: 10
CTGENERATE inference started
Computing cross_attention for ('valid_852_a_2',) with pathology: Pleural effusion
cross_attention shape torch.Size([1, 8, 6464, 4])
Computing cross_attention for ('valid_183_a_1',) with pathology: Hiatal hernia
cross_attention shape torch.Size([1, 8, 6464, 7])
Computing cross_attention for ('valid_183_a_2',) with pathology: Hiatal hernia
cross_attention shape torch.Size([1, 8, 6464, 7])
Computing cross_attention for ('valid_852_a_1',) with pathology: Pleural effusion
cross_attention shape torch.Size([1, 8, 6464, 4])
CTGENERATE inference completed. Total inference Time: 0:03:10.080920
