# 利用 JobAssigner 與 ThreadPoolExecutor，以多執行緒方式執行 model predict

In [1]:
from threading_jobs import JobAssigner
from torchvision.models.resnet import resnet18
from torchvision.models.resnet import ResNet18_Weights 
from torchvision.datasets import CIFAR10
import torch
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm
from typing import Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed


#### 請在至少 8 張 GPU 的環境執行（下方的 worker 會用到編號 1–7 的 GPU）

In [2]:
# The worker table below uses GPU indices 1-7, so at least 8 visible devices are required.
# An explicit raise (instead of `assert`) still fires when Python runs with -O,
# and the message reports how many GPUs were actually found.
N_GPUS_REQUIRED = 8
n_gpus = torch.cuda.device_count()
if n_gpus < N_GPUS_REQUIRED:
    raise RuntimeError(f"This notebook needs >= {N_GPUS_REQUIRED} GPUs, found {n_gpus}.")

#### 下載 CIFAR-10 資料

In [3]:
# Fetch the CIFAR-10 test split once up front, so that predict_one_batch
# can later open the same "CIFAR10/" directory with download=False.
_ = CIFAR10(
    root="CIFAR10/",
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

Files already downloaded and verified


#### 建立 predict function 並使用 JobAssigner decorator
* 注意：我用 worker_names 定義了四個 worker，每個 worker 使用的 GPU 編號與數量都不同；分到多張 GPU 的 worker 會以 DataParallel 執行

In [4]:
# Four workers, each owning a different set of GPU indices (GPU 0 is intentionally unused).
worker_names = [(1, 2), (3,), (4, 5, 6), (7,)]

@JobAssigner(worker_names, max_job_per_worker=2, worker_arg_name='device_ids', if_no_id='raise')
def predict_one_batch(device_ids: Tuple[int, ...]) -> float:
    """Run one forward pass of an untrained ResNet-18 on the first CIFAR-10 test batch.

    The notebook calls this with no arguments (``executor.submit(predict_one_batch)``);
    ``device_ids`` is presumably filled in by the ``JobAssigner`` decorator
    (``worker_arg_name='device_ids'``) with one of the tuples in ``worker_names``.

    Args:
        device_ids: GPU indices assigned to this job. The model and the batch live on
            ``device_ids[0]``; with more than one index the model is wrapped in
            ``DataParallel`` over all of them.

    Returns:
        Top-1 accuracy (a fraction in [0, 1]) on the first 128 test images. The model
        is randomly initialised (no pretrained weights), so values near chance (~0.1)
        are expected.
    """
    device = torch.device(f"cuda:{device_ids[0]}")
    model = resnet18(num_classes=10).to(device)
    if len(device_ids) > 1:
        # DataParallel replicates the module onto every listed GPU and gathers
        # outputs back on device_ids[0], where the inputs/labels are placed below.
        model = torch.nn.DataParallel(model, device_ids=device_ids)
    # Prediction only: BatchNorm must use its running statistics, not per-batch ones.
    model.eval()

    test = CIFAR10("CIFAR10/", download=False, train=False, transform=transforms.ToTensor())
    dataloader = torch.utils.data.DataLoader(test, batch_size=128, shuffle=False)
    x, y = next(iter(dataloader))
    x, y = x.to(device), y.to(device)

    # No autograd graph is needed for inference. Skipping it also keeps per-job GPU
    # memory low, since up to two jobs may share a GPU (max_job_per_worker=2).
    with torch.no_grad():
        pred = model(x)
    acc = (pred.argmax(dim=1) == y).float().mean().item()
    return acc

#### ThreadPoolExecutor 執行的邏輯

In [5]:
def run_job_in_threadpool(max_workers, n_jobs=12, job_fn=None):
    """Submit ``n_jobs`` calls of ``job_fn`` to a thread pool and collect the results.

    Args:
        max_workers: number of threads in the ThreadPoolExecutor, i.e. how many jobs
            may be running at the same time.
        n_jobs: number of jobs to submit (default 12, the count this notebook uses).
        job_fn: zero-argument callable executed once per job; defaults to
            ``predict_one_batch``.

    Returns:
        dict mapping job index (0 .. n_jobs-1) to that job's return value. Keys are
        inserted in completion order, not index order.

    Raises:
        Any exception raised inside a job is re-raised here by ``future.result()``
        (e.g. the ValueError the last cell expects). Before it propagates, the
        executor's ``with`` block still waits for the already-submitted jobs to finish.
    """
    if job_fn is None:
        job_fn = predict_one_batch

    result_dict = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(desc="threading_jobs", total=n_jobs) as pbar:
            future_to_index = {executor.submit(job_fn): index for index in range(n_jobs)}
            for future in as_completed(future_to_index):
                result_dict[future_to_index[future]] = future.result()
                pbar.update()
    return result_dict
            
            
        


#### 執行與結果

In [6]:
# 8 threads == 4 workers x max_job_per_worker=2, so the pool never has more jobs
# in flight than there are worker slots.
result = run_job_in_threadpool(max_workers=8)
result

threading_jobs: 100%|█████████████████████████████████████████████████████████████████████████████| 12/12 [00:13<00:00,  1.16s/it]


{5: 0.0546875,
 8: 0.0859375,
 9: 0.0625,
 1: 0.0703125,
 3: 0.0390625,
 7: 0.109375,
 10: 0.1171875,
 11: 0.1171875,
 4: 0.0859375,
 0: 0.1171875,
 6: 0.0859375,
 2: 0.109375}

#### 因為我們設定 if_no_id='raise' 且 max_job_per_worker=2，4 個 worker 一次最多只能同時執行 4 × 2 = 8 個 job；用 12 個 thread 同時送出 job 時會有 job 分配不到 worker，所以下面會直接 raise ValueError

In [7]:
# With 12 threads, more jobs can be in flight than the 4 workers x 2 slots = 8 available.
# Because JobAssigner was built with if_no_id='raise', this run is expected to fail
# with ValueError; reaching the end without one means the demo is broken.
raised = False
try:
    _ = run_job_in_threadpool(12)
except ValueError:
    raised = True
    print('error raised!')
if not raised:
    raise Exception('error is not raised!')

threading_jobs:   0%|                                                                                      | 0/12 [00:00<?, ?it/s]


error raised!
