# Setup

In [None]:
# wandb is needed because the training command in this notebook passes `--log-wandb`.
!pip install wandb

In [None]:
# Clone the project into /content/pimcls and bring the master branch up to date.
%cd /content
!git clone https://github.com/ethan-jiang-1/pim_classification.git pimcls 
%cd /content/pimcls
!git checkout master
!git pull origin master
# Fetch the repo's submodules (the `pim` directory used below is presumably one of them; TODO confirm).
!git submodule update --init --recursive
# Install the package found in `pim` in editable mode, then return to the repo root.
%cd /content/pimcls/pim
!pip install -e .
%cd /content/pimcls

# Download dataset - Omniglot

In [None]:
import os
import time
from IPython.display import clear_output
import torchvision 
import shutil
from sklearn.model_selection import train_test_split


def make_ds_on(char_folder, ds_folder, random_state=42):
    """Copy every class folder under ``char_folder`` into train/val/test splits.

    Each sub-directory of ``char_folder`` is treated as one class. Its files are
    split 60% / 20% / 20% and copied to ``ds_folder/train/<class>``,
    ``ds_folder/val/<class>`` and ``ds_folder/test/<class>``.

    The split is deterministic: directory listings are sorted and both
    ``train_test_split`` calls receive ``random_state``. Re-running the function
    on the same input therefore copies the same files to the same places instead
    of mixing a fresh random split into the existing folders (which would leak
    samples between train, val and test).

    Args:
        char_folder: Directory whose sub-directories each hold one class.
        ds_folder: Destination root; created if missing. Files already present
            are not removed.
        random_state: Seed forwarded to ``train_test_split``. Pass ``None`` to
            get a different random split on every call.

    Raises:
        ValueError: Propagated from ``train_test_split`` when a class folder has
            too few files to be split.
    """
    os.makedirs(ds_folder, exist_ok=True)
    for fd in sorted(os.listdir(char_folder)):
        fdr = "{}/{}".format(char_folder, fd)
        if not os.path.isdir(fdr):
            continue

        # Sort so the split depends on file names only, not on listdir order.
        fls = sorted(os.listdir(fdr))
        train_fls, holdout_fls = train_test_split(
            fls, test_size=0.4, random_state=random_state)
        # The 40% hold-out is halved into validation and test.
        val_fls, test_fls = train_test_split(
            holdout_fls, test_size=0.5, random_state=random_state)

        splits = (("train", train_fls), ("val", val_fls), ("test", test_fls))
        for split_name, split_fls in splits:
            split_dir = "{}/{}/{}".format(ds_folder, split_name, fd)
            os.makedirs(split_dir, exist_ok=True)
            for fl in split_fls:
                shutil.copy("{}/{}".format(fdr, fl), "{}/{}".format(split_dir, fl))

%cd /content/pimcls

# Download Omniglot only once. `download=True` already performs the download
# inside the constructor, so a separate `ds.download()` call is not needed.
if not os.path.isdir("omniglot_src"):
    ds = torchvision.datasets.Omniglot("omniglot_src", download=True)

# Only the Burmese (Myanmar) alphabet of the background set is used.
src_folder = "/content/pimcls/omniglot_src/omniglot-py/images_background/Burmese_(Myanmar)"
dst_folder = "/content/pimcls/omniglot_myanmar"
ds_folder = "/content/pimcls/omniglot"

# Keep a plain per-character copy; the inference cells read from it.
if not os.path.isdir(dst_folder):
    shutil.copytree(src_folder, dst_folder)

# Build the train/val/test folder layout consumed by train.py.
make_ds_on(dst_folder, ds_folder)


# Prepare training

In [None]:
# Training arguments, one option per line for readability.
cmd = """./train.py ../omniglot
--model seresnet34
--sched cosine
--epochs 80
--warmup-epochs 5
--lr 0.4
--reprob 0.5
--remode pixel
--batch-size 8
--amp
--log-wandb
-j 4"""

# Collapse every run of whitespace (including the newlines) into one space.
# Simply deleting "\n" would fuse neighbouring arguments unless each line
# happened to end with a trailing space.
cmd_line = " ".join(cmd.split())
print(cmd_line)

# Training

In [None]:
# Run the assembled command from the `pim` directory: it starts with `./train.py`, so it must be launched from there.
%cd /content/pimcls/pim 
!$cmd_line

# Inference

In [None]:
def find_model_src(output_folder="/content/pimcls/pim/output/train"):
    """Return the path of the most recently written ``model_best.pth.tar``.

    Every direct sub-directory of ``output_folder`` is checked for a
    ``model_best.pth.tar`` file. When several runs have one, the file with the
    newest modification time wins, so the result does not depend on the
    arbitrary order of ``os.listdir``.

    Args:
        output_folder: Directory holding one sub-directory per training run.

    Returns:
        The checkpoint path as a string, or ``None`` when ``output_folder``
        does not exist or no run contains a ``model_best.pth.tar``.
    """
    if not os.path.isdir(output_folder):
        return None

    candidates = []
    # Sorted so that ties on modification time resolve the same way every run.
    for df in sorted(os.listdir(output_folder)):
        model_best_path = "{}/{}/model_best.pth.tar".format(output_folder, df)
        if os.path.isfile(model_best_path):
            candidates.append(model_best_path)

    if not candidates:
        return None
    return max(candidates, key=os.path.getmtime)


In [None]:
# Keep a copy of the best checkpoint at a fixed location for the cells below.
os.makedirs("/content/pimcls/model", exist_ok=True)
model_src = find_model_src()
model_dst = "/content/pimcls/model/model_best.pth.tar"
# An existing copy is kept as-is; delete it to pick up a newer training run.
if not os.path.isfile(model_dst):
    if model_src is None:
        # Fail with a clear message instead of the TypeError that
        # shutil.copy(None, ...) would raise.
        raise FileNotFoundError(
            "No model_best.pth.tar found under /content/pimcls/pim/output/train; "
            "run the training cell first."
        )
    shutil.copy(model_src, model_dst)


In [None]:
import torch
# Load the checkpoint onto the CPU and show what it contains.
checkpoint = torch.load(model_dst, map_location='cpu')
print()
print(type(checkpoint))
print(checkpoint.keys())
print()

# Print the metadata entries, each with its aligned label.
for label, key in (
    ("epoch  \t", "epoch"),
    ("arch.  \t", "arch"),
    ("version\t", "version"),
    ("metric \t", "metric"),
    ("args   \t", "args"),
):
    print(label, checkpoint[key])


In [None]:
from timm.models import create_model

# Rebuild the architecture named in the checkpoint and load its weights from
# the copied checkpoint file.
model_kwargs = {
    "num_classes": checkpoint["args"].num_classes,
    "in_chans": 3,
    "checkpoint_path": model_dst,
}
model = create_model(checkpoint["arch"], **model_kwargs)
print(model)

Example invocation of the inference script, for reference: `python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar`

In [None]:
def do_inference(subfolder_name):
    %cd /content/pimcls/pim 

    output_folder = "{}".format(subfolder_name)
    print(output_folder)
    os.makedirs(output_folder, exist_ok=True)

    cmd = """./inference.py /content/pimcls/omniglot_myanmar/{} 
    --model seresnet34 
    --checkpoint /content/pimcls/model/model_best.pth.tar 
    --output_dir {}
    """

    cmd = cmd.format(subfolder_name, subfolder_name)
    cmd_line = cmd.replace("\n", "").replace("\t", "")
    print(cmd_line)

    !$cmd_line

In [None]:
do_inference("character01")

In [None]:
do_inference("character02")

In [None]:
do_inference("character03")

In [None]:
do_inference("character04")

In [None]:
do_inference("character34")

# model check

In [None]:
from PIL import Image
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch

def get_transform(model):
    """Build the preprocessing transform that matches ``model``.

    Args:
        model: Model passed to ``resolve_data_config`` to resolve its data
            configuration.

    Returns:
        A ``(transform, config)`` tuple: the transform created from the
        resolved configuration, and the configuration itself.
    """
    data_config = resolve_data_config({}, model=model)
    return create_transform(**data_config), data_config

def pred_poss(model, img_path):
    """Classify one image and print its top-5 and top-10 class probabilities.

    Args:
        model: Classification model; it is switched to eval mode as a side
            effect.
        img_path: Path of the image file to classify.

    Returns:
        A ``(top10_prob, top10_catid)`` tuple of tensors holding the highest
        probabilities and their class indices. Fewer than 10 entries are
        returned when the model has fewer than 10 classes.
    """
    transform, _ = get_transform(model)

    # Inference must run in eval mode: in training mode BatchNorm would use the
    # statistics of this single image and dropout would stay active.
    model.eval()

    # Close the file handle as soon as the pixels are converted.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    print(img.height, img.width)
    tensor = transform(img).unsqueeze(0)
    print(tensor.shape)

    with torch.no_grad():
        out = model(tensor)
    probabilities = torch.nn.functional.softmax(out[0], dim=0)
    print(probabilities.shape)

    # torch.topk raises when k exceeds the number of classes, so clamp k.
    num_classes = probabilities.size(0)

    print("top5")
    top5_prob, top5_catid = torch.topk(probabilities, min(5, num_classes))
    for i in range(top5_prob.size(0)):
        print(i, top5_catid[i], top5_prob[i].item())
    print()

    print("top10")
    top10_prob, top10_catid = torch.topk(probabilities, min(10, num_classes))
    for i in range(top10_prob.size(0)):
        print(i, top10_catid[i], top10_prob[i].item())

    return top10_prob, top10_catid

In [None]:
transform, config = get_transform(model)
print(config)
print(transform)

In [None]:
pred_poss(model, "/content/pimcls/omniglot_myanmar/character01/0770_01.png")

In [None]:
pred_poss(model, "/content/pimcls/omniglot_myanmar/character30/0799_04.png")