In [3]:
from PIL import Image
from torchvision import transforms
import torch
import timm
import numpy as np
from typing import List
import torch.nn as nn
import random
import glob
import re

### VARIABLES ###

# For Pre-processing
custom_transform =  transforms.Compose([
                        transforms.Resize((224, 224)),
                        transforms.ToTensor(),
                    ])

# Model
class Network(nn.Module):
    def __init__(self, model_name, p=0.25):
        super(Network, self).__init__()
        self.model = timm.create_model(model_name, pretrained=True, num_classes=3, drop_rate=0.2)

    def forward(self, img):
        final = self.model(img)
        return final

# For Post-processing
FIRST_ELEM_MAX = 6
FIRST_ELEM_STEP_SIZE = 0.5

SECOND_ELEM_MAX = 1
SECOND_ELEM_STEP_SIZE = 0.1

THIRD_ELEM_MAX = 1
THIRD_ELEM_STEP_SIZE = 0.1

### END OF VARIABLES###

def fix_everything(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed) # if use multi-GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    np.random.seed(random_seed)
    random.seed(random_seed)

def compute(img_paths: List[str], model_ckpt: str, model_name: str):
    
    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Get Model
    model = get_model(model_ckpt, model_name).to(device)

    # Get Image
    img = get_image_tensors(img_paths).to(device)

    # Feed Image into Model
    model.eval()

    with torch.no_grad():
        out = model(img).detach().cpu().numpy()

    return post_process(out)

def get_image_tensors(img_paths: List[str]):
    return torch.cat([custom_transform(Image.open(img_path).convert("RGB")).unsqueeze(0) for img_path in img_paths], dim=0) 

def get_model(model_ckpt: str, model_name:str):
    
    model = Network(model_name)

    try:
        model.load_state_dict(torch.load(model_ckpt))
    except Exception as e:
        keys = torch.load(model_ckpt)
        new_keys = {k[7:]: v for k, v in keys.items()}
        model.load_state_dict(new_keys)
        print(e)

    return model

def post_process(pred):
    pred[pred > 1] = 1
    pred[pred < 0]= 0

    # Since first element minimum is 1:
    pred[:, 0] = np.round(pred[:, 0] * FIRST_ELEM_MAX / FIRST_ELEM_STEP_SIZE ) * FIRST_ELEM_STEP_SIZE
    pred[:, 0][pred[:, 0] < 1] = 1

    pred[:, 1] = np.round(pred[:, 1] * SECOND_ELEM_MAX / SECOND_ELEM_STEP_SIZE) * SECOND_ELEM_STEP_SIZE
    pred[:, 2] = np.round(pred[:, 2] * THIRD_ELEM_MAX / THIRD_ELEM_STEP_SIZE) * THIRD_ELEM_STEP_SIZE
    
    return [{
            "width": pred[idx][0].item(), 
            "dynamics": pred[idx][1].item(),
            "jitter": pred[idx][2].item()
            } for idx in range(len(pred))]    

def get_WDJ_values(img_paths):
    """
    img_paths is a list of image filenames - but image filename should have a specific formatting.
    i.e. one element could be './data/LineStyleData_0/w1.5d0.1j0.1png
    """

    values = []

    for img_title in img_paths:
        matches = re.findall(r'[0-9]+\.?[0-9]*', img_title.split("/")[-1])
        label = [float(x) for x in matches]
        values.append(label)
    
    return values

def get_nrmse(pred, gt, idx):
    pred = pred[:, idx]
    gt = gt[:, idx]
    numerator = np.sqrt(np.mean((gt - pred)**2))
    denominator = np.max(gt) - np.min(gt)
    return numerator / denominator

def get_rsquared(pred, gt, idx):
    pred = pred[:, idx]
    gt =  gt[:, idx]
    return 1 - sum(np.square(gt- pred)) / sum(np.square(gt - np.mean(gt)))
    

def random_eval(path, k=100):
    '''
    A. path: points to a directory where:
     - path
        - sub_path1
            - sub_path1_img1.png
            - sub_path1_img2.png
            - ...
        - sub_path2
            - sub_path2_img1.png
            - sub_path2_img2/png
            - ...

    B. k: number of samples to randomly evaluate.
    '''
    all_imgs = glob.glob(f"{path}/*/*.png")
    all_imgs = [ x for x in all_imgs if '21' not in x and '22' not in x] # If using only training files. Comment this out if necessary.
    chosen_imgs = random.sample(all_imgs, k = 100)

    MODEL_NAME = "regnetv_064" # Fixed for now.
    model_ckpt = "/workspace/LSD/result/FINAL_MODEL.pth" # change this. this model was trained on 0~20.
    result = compute(img_paths=chosen_imgs, model_ckpt=model_ckpt, model_name=MODEL_NAME) # prediction
    
    # Change to numpy array before further calculation.
    result = np.array([[x["width"], x["dynamics"], x["jitter"]] for x in result])
    answer = np.array(get_WDJ_values(chosen_imgs)) # answer (ground truth)

    for idx, x in enumerate(["width", "dynamics", "jitter"]):
        print(f"<Result for {x}>")
        nrmse = get_nrmse(result, answer, idx)
        rsq = get_rsquared(result, answer, idx)
        print(f"NRMSE: {nrmse}")
        print(f"R-Squared: {rsq}")
        print()

In [4]:
# USAGE
fix_everything(42)
data_path = "/workspace/LSD/data"

for i in range(10):
    print("Trial", i+1, ":")
    random_eval(data_path)
    print("---" * 20)
    print()

Trial 1 :
<Result for width>
NRMSE: 0.07810249675906654
R-Squared: 0.9370459048877147
<Result for dynamics>
NRMSE: 0.06708203668160778
R-Squared: 0.9517472882472193
<Result for jitter>
NRMSE: 0.05385164350567124
R-Squared: 0.971737927627039
------------------------------------------------------------

Trial 2 :
<Result for width>
NRMSE: 0.07745966692414834
R-Squared: 0.9313689604685212
<Result for dynamics>
NRMSE: 0.06633249210049212
R-Squared: 0.9568504819187627
<Result for jitter>
NRMSE: 0.058309515702934606
R-Squared: 0.9700374565189621
------------------------------------------------------------

Trial 3 :
<Result for width>
NRMSE: 0.08124038404635961
R-Squared: 0.9305051015573175
<Result for dynamics>
NRMSE: 0.06403123955845387
R-Squared: 0.9554981532883382
<Result for jitter>
NRMSE: 0.052915025460958796
R-Squared: 0.9742173119748258
------------------------------------------------------------

Trial 4 :
<Result for width>
NRMSE: 0.07810249675906654
R-Squared: 0.9305009627325654
<