In [1]:
import sys
import os

# Make the parent of the notebook's working directory importable so the project
# packages imported in the next cell (`utils`, `models`) resolve.
# The membership check keeps this cell idempotent: re-running it does not keep
# appending duplicate entries to sys.path.
PROJECT_ROOT = os.path.abspath("..")
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [7]:
%load_ext autoreload
%autoreload 2

import warnings
import torch
import utils.CTClipInference
from monai.utils import ensure_tuple_rep
from models.ctclip import CTCLIP
from utils.ctvit import CTViT
from transformers import BertTokenizer, BertModel
from transformers.utils import logging
from torch import nn

warnings.simplefilter("ignore")
logging.set_verbosity_error()
torch.set_printoptions(profile="default")
torch.autograd.set_detect_anomaly(False)

# The tokenizer and the text encoder must come from the same checkpoint, so the
# Hugging Face model id is defined once and shared by both.
TEXT_MODEL_NAME = "microsoft/BiomedVLP-CXR-BERT-specialized"

tokenizer = BertTokenizer.from_pretrained(TEXT_MODEL_NAME, do_lower_case=True)
text_encoder = BertModel.from_pretrained(TEXT_MODEL_NAME)
# Size the token-embedding matrix to the tokenizer's vocabulary so every token
# id the tokenizer can emit has an embedding row.
text_encoder.resize_token_embeddings(len(tokenizer))

# CTViT image-encoder hyperparameters. These must match the architecture of the
# pretrained CT-CLIP checkpoint that is loaded into `clip`.
VIT_DIM = 512
IMAGE_SIZE = 480
PATCH_SIZE = 20
TEMPORAL_PATCH_SIZE = 10

# Presumably the size of the joint image/text embedding space — confirm in models/ctclip.py.
dim_latent = 512
# Expected to equal text_encoder.config.hidden_size (768 for a BERT-base encoder).
dim_text = 768

assert IMAGE_SIZE % PATCH_SIZE == 0, "IMAGE_SIZE must be divisible by PATCH_SIZE"
# Derived instead of hard-coded: (480 // 20) ** 2 * 512 == 294912, i.e. one
# VIT_DIM-sized vector per spatial patch, flattened. Deriving it keeps this
# value consistent if the ViT hyperparameters above change.
# NOTE(review): the temporal axis does not appear in this product, which suggests
# CTCLIP reduces it before flattening — confirm in models/ctclip.py.
vit_dim_image = (IMAGE_SIZE // PATCH_SIZE) ** 2 * VIT_DIM

vit_encoder = CTViT(
    dim = VIT_DIM,
    codebook_size = 8192,
    image_size = IMAGE_SIZE,
    patch_size = PATCH_SIZE,
    temporal_patch_size = TEMPORAL_PATCH_SIZE,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 32,
    heads = 8
)

clip = CTCLIP(
    text_encoder = text_encoder,
    image_encoder = vit_encoder,
    dim_text = dim_text,
    dim_image = vit_dim_image,
    dim_latent = dim_latent
)

# Filesystem locations. The two roots can be overridden via environment
# variables so the notebook can be pointed elsewhere without editing this cell;
# the defaults are the paths this notebook was originally run with.
CTCLIP_PROJECT_DIR = os.environ.get("CTCLIP_PROJECT_DIR", "/project/project_465001111/ct_clip")
CTCLIP_SCRATCH_DIR = os.environ.get("CTCLIP_SCRATCH_DIR", "/scratch/project_465001111/ct_clip")
CTCLIP_REPO_DIR = os.path.join(CTCLIP_PROJECT_DIR, "CT-CLIP-UT")

CHECKPOINT_PATH = os.path.join(CTCLIP_PROJECT_DIR, "pretrained_models", "ctclip_v2.pt")
VALID_REPORTS_CSV = os.path.join(CTCLIP_REPO_DIR, "reports", "valid_reports.csv")
VALID_LABELS_CSV = os.path.join(CTCLIP_REPO_DIR, "labels", "valid_labels.csv")
VALID_METADATA_CSV = os.path.join(CTCLIP_REPO_DIR, "metadata", "valid_metadata.csv")
VALID_VOLUMES_DIR = os.path.join(CTCLIP_SCRATCH_DIR, "data_volumes", "dataset", "valid")
RESULTS_DIR = os.path.join(CTCLIP_REPO_DIR, "src", "results", "valid", "ctclip")

clip.load(CHECKPOINT_PATH)

# Run inference on the validation split.
# NOTE(review): with visualize=True the captured run performed an occlusion pass
# over 6912 patches (ETA roughly 700 s for a single volume); presumably setting
# it to False skips that step for a faster metrics-only run — confirm in
# utils/CTClipInference.py.
inference = utils.CTClipInference.CTClipInference(
    clip,
    valid_reports = VALID_REPORTS_CSV,
    data_valid = VALID_VOLUMES_DIR,
    valid_labels = VALID_LABELS_CSV,
    valid_metadata = VALID_METADATA_CSV,
    results_folder = RESULTS_DIR,
    tokenizer = tokenizer,
    batch_size = 1,
    num_workers = 4,
    # Presumably caps the number of validation volumes (the run logged
    # "Validation size: 1"); raise it for a fuller evaluation.
    num_valid_samples = 1,
    zero_shot = False,
    visualize = True
)

inference.infer()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Successfully loaded state dictionary from: /project/project_465001111/ct_clip/pretrained_models/ctclip_v2.pt
Validation size: 1
Evaluation started
occlusion visualization started.
[Rank 0] Total patches to go through: 6912
[Rank 0] Patch 1/6912 (0.01%) - Elapsed: 0.1s - ETA: 696.6s
[Rank 0] Patch 101/6912 (1.46%) - Elapsed: 10.0s - ETA: 676.2s
[Rank 0] Patch 201/6912 (2.91%) - Elapsed: 20.0s - ETA: 666.6s
[Rank 0] Patch 301/6912 (4.35%) - Elapsed: 29.9s - ETA: 656.9s
[Rank 0] Patch 401/6912 (5.80%) - Elapsed: 39.8s - ETA: 647.0s
[Rank 0] Patch 501/6912 (7.25%) - Elapsed: 49.8s - ETA: 637.1s
[Rank 0] Patch 601/6912 (8.70%) - Elapsed: 59.7s - ETA: 627.2s
[Rank 0] Patch 701/6912 (10.14%) - Elapsed: 69.7s - ETA: 617.3s
[Rank 0] Patch 801/6912 (11.59%) - Elapsed: 79.6s - ETA: 607.4s
[Rank 0] Patch 901/6912 (13.04%) - Elapsed: 89.6s - ETA: 597.5s
[Rank 0] Patch 1001/6912 (14.48%) - Elapsed: 99.5s - ETA: 5