In [1]:
# !pip install datasets

In [2]:
import huggingface_hub

huggingface_hub.login("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")

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/brainer/.cache/huggingface/token
Login successful


In [3]:
from datasets import load_dataset

dataset = load_dataset("brainer/drug_info", streaming=True)['train']
dataset

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

IterableDataset({
    features: ['image', 'label', 'bbox'],
    n_shards: 20
})

In [4]:
from huggingface_hub import hf_hub_download
import joblib

drug_name_encoder = joblib.load(
    hf_hub_download("brainer/pill-identifier", "Condensed-Co-Graph-And-Size-Graph/drug_name_encoder.pkl")
)
drug_name_encoder

In [5]:
import pandas as pd

id2label = pd.DataFrame(drug_name_encoder.classes_).to_dict()[0]
label2id = {v: k for k, v in id2label.items()}
id2label, label2id

({0: '0',
  1: '1',
  2: '100',
  3: '1000',
  4: '10000',
  5: '10002',
  6: '10003',
  7: '10004',
  8: '10005',
  9: '10006',
  10: '10007',
  11: '10008',
  12: '10009',
  13: '1001',
  14: '10010',
  15: '10011',
  16: '10012',
  17: '10013',
  18: '10014',
  19: '10015',
  20: '10016',
  21: '10017',
  22: '10018',
  23: '10019',
  24: '1002',
  25: '10020',
  26: '10021',
  27: '10022',
  28: '10023',
  29: '10024',
  30: '10025',
  31: '10026',
  32: '10027',
  33: '10028',
  34: '10029',
  35: '1003',
  36: '10030',
  37: '10031',
  38: '10032',
  39: '10033',
  40: '10034',
  41: '10035',
  42: '10036',
  43: '10037',
  44: '10038',
  45: '10039',
  46: '1004',
  47: '10040',
  48: '10041',
  49: '10042',
  50: '10043',
  51: '10044',
  52: '10045',
  53: '10046',
  54: '10047',
  55: '10048',
  56: '10049',
  57: '1005',
  58: '10050',
  59: '10051',
  60: '10052',
  61: '10053',
  62: '10054',
  63: '10055',
  64: '10056',
  65: '10057',
  66: '10058',
  67: '10059',
  68: 

In [6]:
import numpy as np


# dataset['image']의 투명 하지 않은 부분을 bbox로 인식
# bbox의 좌표는 (x_min, y_min, x_max, y_max)로 표현
def get_bbox(image):
    """
    Find the bounding box of the non-transparent part of a PNG image using optimized method.

    :param image: PIL Image object.
    :return: Tuple (left, upper, right, lower) representing the bounding box.
    """
    if image.mode != 'RGBA':
        image = image.convert('RGBA')

    # Convert the image to a numpy array
    data = np.array(image)

    # Extract the alpha channel and find where it's non-zero
    alpha = data[:, :, 3]
    non_zero_indices = np.argwhere(alpha != 0)

    if non_zero_indices.size == 0:
        return None

    # Find the bounding box coordinates
    min_y, min_x = non_zero_indices.min(axis=0)
    max_y, max_x = non_zero_indices.max(axis=0)

    return [(min_x, min_y, max_x - min_x, max_y - min_y)]


def crop_image(image):
    x, y, w, h = get_bbox(image)[0]
    return image.crop((x, y, x + w, y + h))

In [7]:
# from PIL import ImageDraw
# 
# image = sample["image"]
# annotations = sample["bbox"]
# draw = ImageDraw.Draw(image)
# 
# print(f"annotations: {annotations}")
# 
# class_idx = sample["label"]
# x, y, w, h = annotations
# draw.rectangle((x, y, x + w, y + h), outline="red", width=5)
# draw.text((x, y), id2label[class_idx], fill="white")
# 
# image

# Fine Tuning

In [10]:
from tqdm.notebook import tqdm
import torch
from vit_pytorch import Dino, SimpleViT

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("mps")

model = SimpleViT(
    image_size=256,
    patch_size=32,
    num_classes=len(id2label),
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
)

model.to(device)

learner = Dino(
    model,
    image_size=256,
    hidden_layer='to_latent',  # hidden layer name or index, from which to extract the embedding
    projection_hidden_size=256,  # projector network hidden dimension
    projection_layers=4,  # number of layers in projection network
    num_classes_K=65336,  # output logits dimensions (referenced as K in paper)
    student_temp=0.9,  # student temperature
    teacher_temp=0.04,  # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale=0.4,  # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale=0.5,  # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay=0.9,  # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay=0.9,
    # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.NAdam(learner.parameters(), lr=3e-4)


def collate_fn(batch):
    batch_list = []
    for i, data in enumerate(batch):
        image = crop_image(data["image"])
        image = image.convert("RGB")
        image = image.resize((256, 256))
        image = np.asarray(image).reshape(3, 256, 256)
        batch_list.append(image)
    return batch_list


torch_dataset = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

# for i in tqdm(range(100)):
for i, data in tqdm(enumerate(torch_dataset)):
    loss = learner(torch.tensor(np.asarray(data)).float().to(device))
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update moving average of teacher encoder and teacher centers
    if i % 8 == 0:
        print(f"epoch: {i}, loss: {loss}")

# save your improved networks
torch.save(model.state_dict(), './model/pretrained-net.pt')

0it [00:00, ?it/s]

epoch: 0, loss: 11.02713680267334
epoch: 8, loss: 11.03703784942627
epoch: 16, loss: 11.029053688049316
epoch: 24, loss: 10.992862701416016
epoch: 32, loss: 10.930920600891113
epoch: 40, loss: 10.856046676635742
epoch: 48, loss: 10.778966903686523
epoch: 56, loss: 10.708025932312012
epoch: 64, loss: 10.647597312927246
epoch: 72, loss: 10.590744018554688
epoch: 80, loss: 10.53799819946289
epoch: 88, loss: 10.49332046508789
epoch: 96, loss: 10.450684547424316
epoch: 104, loss: 10.416836738586426
epoch: 112, loss: 10.391324996948242
epoch: 120, loss: 10.376193046569824
epoch: 128, loss: 10.378252983093262
epoch: 136, loss: 10.40432071685791
epoch: 144, loss: 10.466605186462402
epoch: 152, loss: 10.579427719116211
epoch: 160, loss: 10.771158218383789
epoch: 168, loss: 10.993287086486816
epoch: 176, loss: 10.80816650390625
epoch: 184, loss: 10.647680282592773
epoch: 192, loss: 10.44790267944336
epoch: 200, loss: 10.27054214477539
epoch: 208, loss: 10.119525909423828
epoch: 216, loss: 10.000

In [11]:
import gc
import torch

torch.clear_autocast_cache()
gc.collect()
torch.cuda.empty_cache()
torch.mps.empty_cache()

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# origin_paht = "detr-resnet-50_finetuned_pill"
# target_path = "/content/drive/MyDrive/Dacon"

# import os
# os.system(f"cp -r {origin_paht} {target_path}")