In [71]:
import torch
import os
import pandas as pd
from agent import AgentGroup
from predictor.dataset import PredictorDataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import PIL.Image as Im

In [15]:
def read_data(agent_name, subset="validation"):
    """Return the fraction of an agent's baseline predictions that match the label.

    Reads ``./results/agents/baseline/{agent_name}-predictions-on-{subset}-set.csv``
    (relative to the current working directory), rounds ``proba_mean`` to the
    nearest integer to obtain the predicted label, and compares it with ``label``.

    Args:
        agent_name: Name of the agent whose prediction file is read.
        subset: Split name embedded in the file name (default ``"validation"``).

    Returns:
        Mean of the per-row 0/1 success flag, i.e. the accuracy over the file.
    """
    df = pd.read_csv(f"./results/agents/baseline/{agent_name}-predictions-on-{subset}-set.csv")
    # Vectorised equivalent of applying round() / an equality check row by row.
    df['pred_label'] = df['proba_mean'].round().astype(int)
    df['success'] = (df['label'] == df['pred_label']).astype(int)
    return df['success'].mean()

Unnamed: 0,image_id,label,proba_0,proba_1,proba_2,proba_mean,proba_var,pred_label,success
0,agent_one-033740,1,0.99999,0.99999,0.99999,0.999992,1.97069e-11,1,1
1,agent_one-001068,0,0.000719,0.000719,0.000719,0.000539,9.689039e-08,0,1
2,agent_one-008312,0,0.000896,0.000896,0.000896,0.000672,1.504076e-07,0,1
3,agent_one-031938,1,0.997802,0.997802,0.997802,0.998351,9.058841e-07,1,1
4,agent_one-008622,0,0.001044,0.001044,0.001044,0.000783,2.042519e-07,0,1


In [16]:
# Recompute the accuracy through read_data instead of a notebook-global `df`:
# no visible cell defines `df` at notebook scope (the frame inside read_data is
# local), so the bare expression fails on a fresh kernel.
# NOTE(review): "agent_one" is inferred from the image ids in the preview output above — confirm.
read_data("agent_one")

0.852765

In [2]:
# Instantiate the agent group; its `agents` list is inspected in the next cell.
group = AgentGroup()

In [41]:
# Validation-set directory of every agent in the group, in group order.
[
    os.path.join(member.params['dataset_dir'], "validation-set")
    for member in group.agents
]

['/data/mhassan/xray/dataset-1/dataset/validation-set',
 '/data/mhassan/xray/dataset-2/dataset/validation-set',
 '/data/mhassan/xray/dataset-3/dataset/validation-set',
 '/data/mhassan/xray/dataset-4/dataset/validation-set']

In [78]:
class PDataset(Dataset):
    """Images an agent predicted on, labelled with whether each prediction was correct.

    The predictions CSV of ``agent_name`` is loaded from the agent's ``model_dir``
    and augmented with:

    * ``pred_label``   - ``proba_mean`` rounded to the nearest integer,
    * ``success``      - 1 when ``pred_label`` equals ``label``, else 0,
    * ``dataset_name`` - the part of the original ``image_id`` before the first ``-``,
    * ``image_id``     - the part of the original ``image_id`` after the first ``-``.

    Args:
        agent_name: Name passed to ``AgentGroup.get_agent``.
        transform: Callable applied to each RGB PIL image in ``__getitem__``.
        dataset: Split name; selects both the predictions file
            (``{agent}-predictions-on-{dataset}-set.csv``) and the image
            sub-directory (``{dataset}-set``).
        sample_count: Passed to ``pd.read_csv`` as ``nrows``; ``None`` reads all rows.
    """

    def __init__(self, agent_name, transform, dataset="validation", sample_count=None):
        self.agent_name = agent_name
        self.transform = transform
        agent_group = AgentGroup()
        self.agent = agent_group.get_agent(agent_name)

        # Images must come from the same split as the predictions file. The
        # directory used to be hard-coded to "validation-set", which silently
        # paired non-validation predictions with validation images.
        # NOTE(review): assumes split directories are named "<dataset>-set",
        # mirroring the CSV naming — confirm for non-validation splits.
        split_dir = f"{dataset}-set"
        self.img_dirs = {
            agent.name: os.path.join(agent.params['dataset_dir'], split_dir)
            for agent in agent_group.agents
        }

        self.data = pd.read_csv(
            os.path.join(self.agent.model_dir, f"{self.agent.name}-predictions-on-{dataset}-set.csv"),
            nrows=sample_count
        )
        self.data['pred_label'] = self.data['proba_mean'].round().astype(int)
        self.data['success'] = (self.data['label'] == self.data['pred_label']).astype(int)

        # Split once on the first '-' so an image id that itself contains a
        # hyphen is kept whole instead of being truncated to its first piece.
        id_parts = self.data['image_id'].astype(str).str.split('-', n=1)
        self.data['dataset_name'] = id_parts.str[0]
        self.data['image_id'] = id_parts.str[1]

    def __len__(self):
        """Number of prediction rows loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(full image id, transformed image, success as float tensor)`` for row ``idx``."""
        row = self.data.iloc[idx]
        img_path = os.path.join(
            self.img_dirs[row['dataset_name']],
            f"{row['image_id']}.png"
        )
        # Close the file as soon as the pixels are converted; convert() returns
        # a new image, so it stays valid after the source handle is released.
        with Im.open(img_path) as source:
            img = source.convert('RGB')
        img = self.transform(img)

        return '-'.join([row['dataset_name'], row['image_id']]), img, torch.tensor(row['success'], dtype=torch.float)


In [79]:
# Build the dataset from the first 100 prediction rows of agent_one and preview it.
dataset = PDataset(
    agent_name="agent_one",
    transform=transforms.ToTensor(),
    sample_count=100,
)
dataset.data.head()

Unnamed: 0,image_id,label,proba_0,proba_1,proba_2,proba_mean,proba_var,pred_label,success,dataset_name
0,33740,1,0.99999,0.99999,0.99999,0.999992,1.97069e-11,1,1,agent_one
1,1068,0,0.000719,0.000719,0.000719,0.000539,9.689039e-08,0,1,agent_one
2,8312,0,0.000896,0.000896,0.000896,0.000672,1.504076e-07,0,1,agent_one
3,31938,1,0.997802,0.997802,0.997802,0.998351,9.058841e-07,1,1,agent_one
4,8622,0,0.001044,0.001044,0.001044,0.000783,2.042519e-07,0,1,agent_one


In [80]:
loader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=8,
)

# Each batch unpacks to (image ids, image tensor, success labels); only the
# tensor shapes are printed as a sanity check.
for idx, images, labels in loader:
    print(images.shape, labels.shape)


torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([16, 3, 500, 500]) torch.Size([16])
torch.Size([4, 3, 500, 500]) torch.Size([4])
