In [1]:
# !pip install datasets

In [2]:
import os

import huggingface_hub

# Never commit an access token to the notebook: read it from the HF_TOKEN
# environment variable instead. When the variable is unset, `token=None` makes
# `login` fall back to its interactive prompt.
# NOTE(review): any token that was previously hard-coded in this cell is
# exposed in the notebook history and should be revoked on the Hub.
huggingface_hub.login(token=os.environ.get("HF_TOKEN"))

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/brainer/.cache/huggingface/token
Login successful


In [3]:
from datasets import load_dataset

# Open the dataset in streaming mode, so samples are iterated lazily, and keep
# only the 'train' split for the rest of the notebook.
drug_info_splits = load_dataset("brainer/drug_info", streaming=True)
dataset = drug_info_splits['train']
dataset

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

IterableDataset({
    features: ['image', 'label', 'bbox'],
    n_shards: 20
})

In [4]:
from huggingface_hub import hf_hub_download
import joblib

# Download the pickled drug-name encoder from the Hub and deserialize it.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load artifacts from repositories you trust.
drug_name_encoder_path = hf_hub_download(
    repo_id="brainer/pill-identifier",
    filename="Condensed-Co-Graph-And-Size-Graph/drug_name_encoder.pkl",
)
drug_name_encoder = joblib.load(drug_name_encoder_path)
drug_name_encoder

In [5]:
import pandas as pd

# Column 0 of the one-column frame maps each positional index in
# `drug_name_encoder.classes_` to the class value stored at that position.
class_table = pd.DataFrame(drug_name_encoder.classes_)
id2label = class_table.to_dict()[0]
# Reverse lookup: class value -> positional index.
label2id = dict(zip(id2label.values(), id2label.keys()))
id2label, label2id

({0: '195900001',
  1: '195900032',
  2: '195900043',
  3: '196000001',
  4: '196000003',
  5: '196000008',
  6: '196000011',
  7: '196200002',
  8: '196200032',
  9: '196200034',
  10: '196200046',
  11: '196300001',
  12: '196300019',
  13: '196300064',
  14: '196400005',
  15: '196400037',
  16: '196400046',
  17: '196400099',
  18: '196500004',
  19: '196500051',
  20: '196600007',
  21: '196600011',
  22: '196600012',
  23: '196700049',
  24: '196700060',
  25: '196800048',
  26: '196800078',
  27: '196900043',
  28: '197000005',
  29: '197000037',
  30: '197000040',
  31: '197000049',
  32: '197000050',
  33: '197000053',
  34: '197000068',
  35: '197000076',
  36: '197000079',
  37: '197000102',
  38: '197000104',
  39: '197000196',
  40: '197000208',
  41: '197000211',
  42: '197000212',
  43: '197000215',
  44: '197100015',
  45: '197100073',
  46: '197100081',
  47: '197100088',
  48: '197100097',
  49: '197100165',
  50: '197200016',
  51: '197200084',
  52: '197200091',
  5

In [6]:
import numpy as np


# Treat the non-transparent part of dataset['image'] as the bounding box.
# The box is expressed as (x_min, y_min, x_max - x_min, y_max - y_min).
def get_bbox(image):
    """
    Locate the non-transparent region of an image through its alpha channel.

    :param image: PIL Image object; it is converted to RGBA when it is in
        another mode.
    :return: ``None`` when every pixel has alpha 0, otherwise a one-element
        list ``[(x_min, y_min, x_max - x_min, y_max - y_min)]`` where the
        min/max values are the extreme column/row indices with non-zero alpha.
    """
    rgba = image if image.mode == 'RGBA' else image.convert('RGBA')

    # Channel index 3 of the RGBA array is the alpha channel.
    opacity = np.array(rgba)[:, :, 3]
    rows, cols = np.nonzero(opacity != 0)

    # No pixel with non-zero alpha: there is nothing to bound.
    if rows.size == 0:
        return None

    left, top = cols.min(), rows.min()
    right, bottom = cols.max(), rows.max()

    return [(left, top, right - left, bottom - top)]


def crop_image(image):
    """
    Crop an image to the bounding box of its non-transparent pixels.

    :param image: PIL Image object, passed to ``get_bbox`` and then cropped.
    :return: The cropped image, or ``image`` itself when ``get_bbox`` finds no
        non-transparent pixel (it returns ``None`` in that case).
    """
    bbox = get_bbox(image)
    if bbox is None:
        # Fully transparent image: there is no region to crop to.
        return image

    x, y, w, h = bbox[0]
    # `w`/`h` are max - min index differences, while the right/lower edge of a
    # PIL crop box is exclusive, so add 1 to keep the last opaque row/column.
    return image.crop((x, y, x + w + 1, y + h + 1))

In [7]:
# from PIL import ImageDraw
# 
# image = sample["image"]
# annotations = sample["bbox"]
# draw = ImageDraw.Draw(image)
# 
# print(f"annotations: {annotations}")
# 
# class_idx = sample["label"]
# x, y, w, h = annotations
# draw.rectangle((x, y, x + w, y + h), outline="red", width=5)
# draw.text((x, y), id2label[class_idx], fill="white")
# 
# image

# Fine Tuning

In [9]:
from tqdm.notebook import tqdm
import torch
from vit_pytorch import ViT, Dino

# Pick the best available accelerator instead of hard-coding one backend:
# CUDA first, then Apple MPS, then plain CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# ViT backbone with one output per entry of `id2label`.
model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=len(id2label),
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048
)

model.to(device)

# DINO self-supervised wrapper around the backbone.
learner = Dino(
    model,
    image_size=256,
    hidden_layer='to_latent',  # hidden layer name or index, from which to extract the embedding
    projection_hidden_size=256,  # projector network hidden dimension
    projection_layers=4,  # number of layers in projection network
    num_classes_K=65336,  # output logits dimensions (referenced as K in paper)
    student_temp=0.9,  # student temperature
    teacher_temp=0.04,  # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale=0.4,  # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale=0.5,  # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay=0.9,  # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay=0.9,
    # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.NAdam(learner.parameters(), lr=3e-4)


def collate_fn(batch, image_size=256):
    """
    Turn a batch of dataset samples into channel-first image arrays.

    Each sample's image is cropped with ``crop_image``, converted to RGB,
    resized to a square and converted to a numpy array.

    :param batch: Iterable of samples; each must hold a PIL image under the
        ``"image"`` key.
    :param image_size: Side length of the resized square image. Defaults to
        256, the size the notebook's ViT and Dino are configured with.
    :return: List with one array of shape (3, image_size, image_size) per
        sample.
    """
    batch_list = []
    for data in batch:
        image = crop_image(data["image"])
        image = image.convert("RGB")
        image = image.resize((image_size, image_size))
        # np.asarray yields (height, width, channel); move the channel axis to
        # the front. A plain reshape would keep the memory order and scramble
        # pixels across channels instead of rearranging the axes.
        batch_list.append(np.asarray(image).transpose(2, 0, 1))
    return batch_list


import os

# Create the output directory up front so a missing folder is reported before
# training starts rather than when the weights are saved at the very end.
os.makedirs('./model', exist_ok=True)

# Batches of 16 samples are assembled by `collate_fn`.
torch_dataset = torch.utils.data.DataLoader(dataset, batch_size=16, collate_fn=collate_fn)

for i, data in tqdm(enumerate(torch_dataset)):
    # Stack the per-image arrays of the batch into one float tensor on the
    # training device and let the learner compute its loss.
    loss = learner(torch.tensor(np.asarray(data)).float().to(device))
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average()  # update moving average of teacher encoder and teacher centers
    # NOTE: `i` is the batch index within this single pass, not an epoch count.
    if i % 64 == 0:
        print(f"epoch: {i}, loss: {loss}")

# save your improved networks
torch.save(model.state_dict(), './model/pretrained-net.pt')



0it [00:00, ?it/s]

epoch: 0, loss: 11.022136688232422
epoch: 64, loss: 10.640785217285156


KeyboardInterrupt: 

In [None]:
import gc
import torch

# Release cached memory after an interrupted or finished training run.
torch.clear_autocast_cache()
gc.collect()
torch.cuda.empty_cache()

# Also release the MPS allocator's cache when that backend exists and is
# available; `torch.cuda.empty_cache()` does not cover it.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available() and hasattr(torch, "mps"):
    torch.mps.empty_cache()

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# origin_paht = "detr-resnet-50_finetuned_pill"
# target_path = "/content/drive/MyDrive/Dacon"

# import os
# os.system(f"cp -r {origin_paht} {target_path}")