# Visualising Predictions

### 1. Imports and Model Initialization

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from PIL import Image
from utils import FinalModel, visualize_prediction
from config_io import save_to_config, get_config_value
from transformers import RobertaModel, AutoTokenizer

# Tokenizer and text encoder are both built from the same pretrained checkpoint.
_PRETRAINED_NAME = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(_PRETRAINED_NAME, add_prefix_space=True)
# NOTE(review): no cell in this notebook references `text_encoder`; confirm it is
# still needed before keeping this (it downloads and loads the full model).
text_encoder = RobertaModel.from_pretrained(_PRETRAINED_NAME)

In [None]:
# Select the compute device and resolve which checkpoint file to load.
device = "cuda" if torch.cuda.is_available() else "cpu"
choices = ["final_weights.pth", "best_weights.pth"]  # Choose: set `i` to pick one
i = 0
# Directory tried when the configured CHECKPOINTS_PATH does not hold the file.
fallback_ckpt_dir = Path("/content/drive/MyDrive/refcoco_project/Weights_Checkpoint")

ckpt_dir = Path(get_config_value("CHECKPOINTS_PATH", default="not-set"))
candidate = ckpt_dir / choices[i]

if candidate.is_file():                         # real existence test
    model_save_weight_path = candidate
else:
    print(f"Please provide the link for {choices[i]}")
    # Look for the *selected* checkpoint in the fallback directory, not a fixed name,
    # so that `i` is honoured on both paths.
    model_save_weight_path = fallback_ckpt_dir / choices[i]
    if not model_save_weight_path.is_file():
        raise FileNotFoundError(
            f"Checkpoint '{choices[i]}' not found in '{ckpt_dir}' or "
            f"'{fallback_ckpt_dir}'. Set CHECKPOINTS_PATH in the config."
        )
    # Persist the directory only once it is known to contain the checkpoint.
    save_to_config({"CHECKPOINTS_PATH": str(model_save_weight_path.parent)})

In [None]:
# Build the network and load the selected checkpoint onto the active device.
model = FinalModel()
# map_location remaps stored tensors to `device`, so a checkpoint saved on a GPU
# can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True where the installed torch version supports it).
state_dict = torch.load(model_save_weight_path, map_location=device)
model.load_state_dict(state_dict)
model.to(device)
# Switch layers such as dropout / batch-norm to inference behaviour.
model.eval()

### 2. Evaluate

In [None]:
# Fill these in before running: the referring expression and the image file name.
phrase = ""
filename = ""

# An empty name would resolve to the directory itself and trigger a misleading prompt.
if not filename:
    raise ValueError("Set `filename` to the name of the image file to evaluate.")

base_dir      = Path(get_config_value("OUT_DIR", default="not-set"))
img_path      = base_dir / filename

# A directory entered in an earlier run is stored as TEST_IMG_DIR; try it next so
# the user is not asked for it again.
if not img_path.is_file():
    saved_path = Path(get_config_value("TEST_IMG_DIR", default="not-set")) / filename
    if saved_path.is_file():
        img_path = saved_path

if not img_path.is_file():
    new_dir = Path(
        input(
            f"File '{img_path}' not found.\n"
            f"Enter the directory that contains '{filename}': "
        ).strip()
    ).expanduser().resolve()

    # sanity-check the new directory
    if not new_dir.is_dir():
        raise FileNotFoundError(f"‘{new_dir}’ is not a directory.")

    img_path = new_dir / filename
    if not img_path.is_file():
        raise FileNotFoundError(f"‘{img_path}’ still does not exist.")

    # Remember the directory only after confirming it actually holds the image.
    save_to_config({"TEST_IMG_DIR": str(new_dir)})

In [None]:
# Run this cell to see the image before the predicted boxes are drawn.
# Image.open is lazy and keeps the file open; the `with` block closes it. convert()
# returns a new, fully loaded image, so `img` stays usable afterwards.
with Image.open(img_path) as src:
    img = src.convert("RGB")
img_np = np.array(img)
plt.imshow(img_np)
plt.axis('off')
plt.title("The image")
plt.show()

In [None]:
# Run the loaded model on the chosen image and phrase via utils.visualize_prediction
# (presumably plots the predicted box; see that helper for details).
visualize_prediction(
    model,
    img_path,
    phrase,
    tokenizer,
    device,
)