In [7]:
import numpy as np
import torch
from vit_pytorch import SimpleViT

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")

In [2]:
import pandas as pd
from huggingface_hub import hf_hub_download
import joblib

drug_name_encoder = joblib.load(
    hf_hub_download("brainer/pill-identifier", "Condensed-Co-Graph-And-Size-Graph/drug_name_encoder.pkl")
)

id2label = pd.DataFrame(drug_name_encoder.classes_).to_dict()[0]
label2id = {v: k for k, v in id2label.items()}
id2label, label2id

({0: '0',
  1: '1',
  2: '100',
  3: '1000',
  4: '10000',
  5: '10002',
  6: '10003',
  7: '10004',
  8: '10005',
  9: '10006',
  10: '10007',
  11: '10008',
  12: '10009',
  13: '1001',
  14: '10010',
  15: '10011',
  16: '10012',
  17: '10013',
  18: '10014',
  19: '10015',
  20: '10016',
  21: '10017',
  22: '10018',
  23: '10019',
  24: '1002',
  25: '10020',
  26: '10021',
  27: '10022',
  28: '10023',
  29: '10024',
  30: '10025',
  31: '10026',
  32: '10027',
  33: '10028',
  34: '10029',
  35: '1003',
  36: '10030',
  37: '10031',
  38: '10032',
  39: '10033',
  40: '10034',
  41: '10035',
  42: '10036',
  43: '10037',
  44: '10038',
  45: '10039',
  46: '1004',
  47: '10040',
  48: '10041',
  49: '10042',
  50: '10043',
  51: '10044',
  52: '10045',
  53: '10046',
  54: '10047',
  55: '10048',
  56: '10049',
  57: '1005',
  58: '10050',
  59: '10051',
  60: '10052',
  61: '10053',
  62: '10054',
  63: '10055',
  64: '10056',
  65: '10057',
  66: '10058',
  67: '10059',
  68: 

In [8]:
model = SimpleViT(
    image_size=256,
    patch_size=32,
    num_classes=len(id2label),
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
)

model.load_state_dict(torch.load("./model/pretrained-net.pt", map_location=device))
model

SimpleViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Transformer(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (to_out): Linear(in_features=512, out_features=1024, bias=False)
        )
        (1): FeedForward(
          (net): Sequential(
            (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1024, out_features=2048, bias=True)
            (2): GELU(approximate='no

In [9]:
model(torch.rand(1, 3, 256, 256))

tensor([[-0.2884, -0.2283, -0.4925,  ..., -0.4102, -0.5404,  0.4078]],
       grad_fn=<AddmmBackward0>)

# Fine tune with MLP
## Change last layer

In [10]:
import torch

for param in model.parameters():
    param.requires_grad = False

last_layer = torch.nn.Sequential(torch.nn.Linear(1024, out_features=len(label2id)), torch.nn.ReLU())
print(last_layer)

model.mlp_head = last_layer
model.mlp_head.requires_grad_(True)
model

Sequential(
  (0): Linear(in_features=1024, out_features=50783, bias=True)
  (1): ReLU()
)


SimpleViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=32, p2=32)
    (1): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=3072, out_features=1024, bias=True)
    (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (transformer): Transformer(
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): Attention(
          (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (attend): Softmax(dim=-1)
          (to_qkv): Linear(in_features=1024, out_features=1536, bias=False)
          (to_out): Linear(in_features=512, out_features=1024, bias=False)
        )
        (1): FeedForward(
          (net): Sequential(
            (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (1): Linear(in_features=1024, out_features=2048, bias=True)
            (2): GELU(approximate='no

# Fine tune

In [11]:
# label2id keys to int

for key in list(label2id.keys()):
    label2id[int(key)] = label2id.pop(key)

label2id

{0: 0,
 1: 1,
 100: 2,
 1000: 3,
 10000: 4,
 10002: 5,
 10003: 6,
 10004: 7,
 10005: 8,
 10006: 9,
 10007: 10,
 10008: 11,
 10009: 12,
 1001: 13,
 10010: 14,
 10011: 15,
 10012: 16,
 10013: 17,
 10014: 18,
 10015: 19,
 10016: 20,
 10017: 21,
 10018: 22,
 10019: 23,
 1002: 24,
 10020: 25,
 10021: 26,
 10022: 27,
 10023: 28,
 10024: 29,
 10025: 30,
 10026: 31,
 10027: 32,
 10028: 33,
 10029: 34,
 1003: 35,
 10030: 36,
 10031: 37,
 10032: 38,
 10033: 39,
 10034: 40,
 10035: 41,
 10036: 42,
 10037: 43,
 10038: 44,
 10039: 45,
 1004: 46,
 10040: 47,
 10041: 48,
 10042: 49,
 10043: 50,
 10044: 51,
 10045: 52,
 10046: 53,
 10047: 54,
 10048: 55,
 10049: 56,
 1005: 57,
 10050: 58,
 10051: 59,
 10052: 60,
 10053: 61,
 10054: 62,
 10055: 63,
 10056: 64,
 10057: 65,
 10058: 66,
 10059: 67,
 1006: 68,
 10060: 69,
 10061: 70,
 10062: 71,
 10063: 72,
 10064: 73,
 10065: 74,
 10066: 75,
 10067: 76,
 10068: 77,
 10069: 78,
 1007: 79,
 10071: 80,
 10072: 81,
 10073: 82,
 10074: 83,
 10075: 84,
 10076: 

In [12]:
from datasets import load_dataset

dataset = load_dataset("brainer/drug_info", streaming=True)['train']
class_label = dataset.info.features['label']

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

In [13]:
# dataset['image']의 투명 하지 않은 부분을 bbox로 인식
# bbox의 좌표는 (x_min, y_min, x_max, y_max)로 표현
def get_bbox(image):
    """
    Find the bounding box of the non-transparent part of a PNG image using optimized method.

    :param image: PIL Image object.
    :return: Tuple (left, upper, right, lower) representing the bounding box.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    # Convert the image to a numpy array
    data = np.array(image)

    # Extract the alpha channel and find where it's non-zero
    alpha = data[:, :, 3]
    non_zero_indices = np.argwhere(alpha != 0)

    if non_zero_indices.size == 0:
        return None

    # Find the bounding box coordinates
    min_y, min_x = non_zero_indices.min(axis=0)
    max_y, max_x = non_zero_indices.max(axis=0)

    return [(min_x, min_y, max_x - min_x, max_y - min_y)]


def crop_image(image):
    x, y, w, h = get_bbox(image)[0]
    return image.crop((x, y, x + w, y + h))

In [14]:
from tqdm.notebook import tqdm

model.train()
model = model.to(device)
opt = torch.optim.NAdam(model.parameters(), lr=3e-4)
loss_fn = torch.nn.CrossEntropyLoss()


def collate_fn(batch):
    batch_image_list = []
    batch_label_list = []
    for _, data in enumerate(batch):
        image = crop_image(data["image"])
        image = image.convert("RGB")
        image = image.resize((256, 256))
        image = np.asarray(image).reshape(3, 256, 256)
        batch_image_list.append(image)
        batch_label_list.append(data["label"])
    return {'image': batch_image_list, 'label': batch_label_list}


batch_size = 16

torch_dataset = torch.utils.data.DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

for i, data in tqdm(enumerate(torch_dataset)):
    predict = model(torch.tensor(np.asarray(data['image'])).to(device))

    truth_label = [label2id[int(label)] for label in [class_label.int2str(i) for i in data['label']]]
    truth = torch.zeros(batch_size, len(label2id))

    for i, label in enumerate(truth_label):
        truth[i, label] = 1

    pred_label = []
    for label in torch.argmax(predict, dim=1):
        if len(class_label.names) > label:
            pred_label.append(label2id[int(class_label.int2str(label))])
        else:
            pred_label.append(label2id[int(class_label.int2str(0))])

    pred = torch.zeros(batch_size, len(label2id))
    for i, label in enumerate(pred_label):
        pred[i, label] = 1

    # print(f"pred: {pred}, truth: {truth}")
    opt.zero_grad()
    loss = loss_fn(pred.clone().detach().requires_grad_(True), truth.clone().detach().requires_grad_(True))
    loss.backward()
    opt.step()
    if loss.item() < 0.1:
        break
    if i % 100 == 0:
        print(f"loss: {loss.item()}")

0it [00:00, ?it/s]

In [16]:
model_path = "./model/dino-finetuned-net.pt"

torch.save(model, model_path)

from huggingface_hub import HfApi

api = HfApi()

api.upload_file(
    path_or_fileobj=model_path,
    path_in_repo=f"Feature_Extractor/dino-finetuned-net.pt",
    repo_id="brainer/PGPNet")

dino-finetuned-net.pt:   0%|          | 0.00/581M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/brainer/PGPNet/commit/de529ff1d363b99daeddd774c43133346b6c5539', commit_message='Upload Feature_Extractor/dino-finetuned-net.pt with huggingface_hub', commit_description='', oid='de529ff1d363b99daeddd774c43133346b6c5539', pr_url=None, pr_revision=None, pr_num=None)

In [17]:
batch_size = 16
sample_label = torch.zeros(batch_size, len(label2id))
sample_label[0, 0] = 1
sample_label

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [None]:
sample = next(iter(torch_dataset))
sample

In [None]:
# [label2id[int(label)] for label in [class_label.int2str(i) for i in sample['label']]]


# predict_label = [label2id[int(label)] for label in [class_label.int2str(i) for i in torch.argmax(predict, dim=1)]]

pred = model(torch.rand(16, 3, 256, 256).to(device))
[class_label.int2str(int(i)) for i in torch.argmax(pred, dim=1)]

# Predict

In [None]:
from PIL import Image
import requests
from io import BytesIO

img_url = "https://img.freepik.com/free-photo/minimal-medicinal-pills-arrangement_23-2148892392.jpg?size=626&ext=jpg&ga=GA1.1.1412446893.1704499200&semt=sph"

response = requests.get(img_url)
img = Image.open(BytesIO(response.content))
img = img.resize(size=(256, 256))
model = model.to(device)
result = model(torch.from_numpy(np.asarray(img)).permute(2, 0, 1).unsqueeze(0).to(device))

In [None]:
id2label[int(torch.argmax(result))]