In [None]:
# 라이브러리 설치
!pip install --upgrade pip
!pip install torch torchvision torchaudio
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git

# GPU 사용 확인
import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

In [6]:
import numpy as np
import matplotlib.pyplot as plt

from  tqdm import tqdm

import torch
import torchvision
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

In [7]:
!git clone https://github.com/lvyilin/pytorch-fgvc-dataset.git

Cloning into 'pytorch-fgvc-dataset'...
remote: Enumerating objects: 32, done.[K
remote: Counting objects: 100% (32/32), done.[K
remote: Compressing objects: 100% (19/19), done.[K
remote: Total 32 (delta 16), reused 29 (delta 13), pack-reused 0 (from 0)[K
Receiving objects: 100% (32/32), 13.70 KiB | 13.70 MiB/s, done.
Resolving deltas: 100% (16/16), done.


In [8]:
!mv pytorch-fgvc-dataset data

In [9]:
from data.aircraft import Aircraft

In [11]:
!mkdir data/aircraft

In [23]:
IMG_SIZE = 224

transform = T.Compose(
    [
        T.Resize((IMG_SIZE,IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),

    ]
)

In [24]:
train_ds = Aircraft('./data/aircraft', train=True, download=True, transform=transform)
test_ds = Aircraft('./data/aircraft', train=False, download=True, transform=transform)

In [31]:
image_ids, targets, classes, class_to_idx = train_ds.find_classes()

In [32]:
CLASSES = [c[:-1] for c in classes]
CLS2IDX = {c[:-1]:idx for c, idx in class_to_idx.items()}

In [39]:
torch.cuda.empty_cache()

In [45]:
train_dl2 = torch.utils.data.DataLoader(train_ds, batch_size=1,
                                      pin_memory=True)

test_dl = torch.utils.data.DataLoader(test_ds, batch_size=1,
                                      pin_memory=True)

In [46]:
import torch
import clip
from torch.utils.data import DataLoader
from tqdm import tqdm

# CLIP 모델 및 토크나이저 불러오기
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# 데이터 로더 준비
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size=32, shuffle=False, pin_memory=True)

# CLASSES를 텍스트로 변환하여 텍스트 임베딩 생성
text_inputs = torch.cat([clip.tokenize(f"This is a photo of a {c}") for c in CLASSES]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)
    text_features /= text_features.norm(dim=-1, keepdim=True)  # 정규화

# 모델 평가 모드로 설정
model.eval()

# 훈련 데이터에서 이미지 임베딩 계산
def extract_image_features(data_loader, model):
    all_image_features = []
    all_labels = []

    for images, labels in tqdm(data_loader):
        images = images.to(device)
        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)  # 정규화
        all_image_features.append(image_features)
        all_labels.append(labels)

    return torch.cat(all_image_features), torch.cat(all_labels)

# 추론 및 성능 평가 함수
def evaluate(model, data_loader, text_features, class_names):
    correct = 0
    total = 0

    for images, labels in tqdm(data_loader):
        images = images.to(device)
        labels = labels.to(device)

        # 이미지 임베딩 계산
        with torch.no_grad():
            image_features = model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)

        # 이미지와 텍스트 임베딩 간의 유사도 계산
        similarity = (100.0 * image_features @ text_features.T)
        predictions = similarity.argmax(dim=-1)

        # 정확도 계산
        correct += (predictions == labels).sum().item()
        total += labels.size(0)

    accuracy = correct / total
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# 이미지 특징 벡터와 성능 평가
print("Extracting features from training set...")
train_image_features, train_labels = extract_image_features(train_dl, model)

print("Evaluating test set...")
accuracy = evaluate(model, test_dl, text_features, CLASSES)

Extracting features from training set...



  0%|          | 0/209 [00:00<?, ?it/s][A
  0%|          | 1/209 [00:00<03:24,  1.02it/s][A
  1%|          | 2/209 [00:01<02:18,  1.49it/s][A
  1%|▏         | 3/209 [00:01<01:59,  1.73it/s][A
  2%|▏         | 4/209 [00:02<01:45,  1.94it/s][A
  2%|▏         | 5/209 [00:02<01:42,  1.99it/s][A
  3%|▎         | 6/209 [00:03<01:37,  2.09it/s][A
  3%|▎         | 7/209 [00:03<01:33,  2.16it/s][A
  4%|▍         | 8/209 [00:04<01:32,  2.17it/s][A
  4%|▍         | 9/209 [00:04<01:30,  2.20it/s][A
  5%|▍         | 10/209 [00:05<01:32,  2.15it/s][A
  5%|▌         | 11/209 [00:05<01:31,  2.16it/s][A
  6%|▌         | 12/209 [00:05<01:31,  2.15it/s][A
  6%|▌         | 13/209 [00:06<01:30,  2.17it/s][A
  7%|▋         | 14/209 [00:06<01:30,  2.16it/s][A
  7%|▋         | 15/209 [00:07<01:35,  2.04it/s][A
  8%|▊         | 16/209 [00:08<01:47,  1.79it/s][A
  8%|▊         | 17/209 [00:08<01:53,  1.69it/s][A
  9%|▊         | 18/209 [00:09<01:59,  1.60it/s][A
  9%|▉         | 19/209 [00:1

Evaluating test set...



  0%|          | 0/105 [00:00<?, ?it/s][A
  1%|          | 1/105 [00:00<00:53,  1.93it/s][A
  2%|▏         | 2/105 [00:00<00:48,  2.12it/s][A
  3%|▎         | 3/105 [00:01<00:45,  2.23it/s][A
  4%|▍         | 4/105 [00:01<00:45,  2.22it/s][A
  5%|▍         | 5/105 [00:02<00:45,  2.19it/s][A
  6%|▌         | 6/105 [00:03<00:54,  1.83it/s][A
  7%|▋         | 7/105 [00:03<00:59,  1.65it/s][A
  8%|▊         | 8/105 [00:04<01:04,  1.50it/s][A
  9%|▊         | 9/105 [00:05<01:05,  1.47it/s][A
 10%|▉         | 10/105 [00:06<01:07,  1.40it/s][A
 10%|█         | 11/105 [00:06<01:05,  1.44it/s][A
 11%|█▏        | 12/105 [00:07<00:56,  1.63it/s][A
 12%|█▏        | 13/105 [00:07<00:52,  1.77it/s][A
 13%|█▎        | 14/105 [00:08<00:47,  1.90it/s][A
 14%|█▍        | 15/105 [00:08<00:45,  1.98it/s][A
 15%|█▌        | 16/105 [00:08<00:45,  1.97it/s][A
 16%|█▌        | 17/105 [00:09<00:43,  2.04it/s][A
 17%|█▋        | 18/105 [00:09<00:41,  2.08it/s][A
 18%|█▊        | 19/105 [00:1

Accuracy: 17.97%





In [93]:
import torch.nn as nn

class SPPR(nn.Module):
    def __init__(self, feature_dim):
        super(SPPR, self).__init__()
        self.transform_new = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU()
        )
        self.transform_old = nn.Sequential(
            nn.Conv1d(feature_dim, feature_dim, kernel_size=1),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU()
        )

    def forward(self, new_embeddings, old_prototypes):
        new_projected = self.transform_new(new_embeddings)
        old_projected = self.transform_old(old_prototypes)

        # 코사인 유사도로 관계 매트릭스 생성
        relation_matrix = torch.matmul(new_projected, old_projected.T)

        # 프로토타입 업데이트
        refined_prototypes = torch.matmul(relation_matrix, old_prototypes)
        return refined_prototypes

In [94]:
# SPPR 객체 생성
feature_dim = 512  # CLIP 모델의 특징 차원
sppr = SPPR(feature_dim=feature_dim).to(device)

In [121]:
# FGVC-Aircraft 데이터셋의 전체 클래스 수
total_classes = len(CLASSES)  # FGVC 데이터셋의 클래스 수
classes_per_session = 5  # 한 세션에 추가되는 클래스 수

# 세션별 클래스 분리
incremental_sessions = [
    CLASSES[i:i + classes_per_session]
    for i in range(0, total_classes, classes_per_session)
]

In [123]:
from torch.utils.data import DataLoader, Subset

def create_session_loader(dataset, session_classes, batch_size=32):
    # 세션 클래스에 해당하는 인덱스 필터링
    indices = [i for i in range(len(dataset)) if dataset[i][1] in session_classes]
    session_subset = Subset(dataset, indices)
    return DataLoader(session_subset, batch_size=batch_size, shuffle=True)

In [130]:
cached_labels = [label for _, label in train_ds.samples]

def create_session_loader(dataset, session_classes, batch_size=32):
    # 세션 클래스에 해당하는 인덱스 필터링
    indices = [i for i in range(len(dataset)) if dataset[i][1] in session_classes]
    session_subset = Subset(dataset, indices)
    return DataLoader(session_subset, batch_size=batch_size, shuffle=True)

# 세션별 데이터 로더 생성
session_loaders = [
    create_session_loader(dataset=train_ds, session_classes=session_classes, batch_size=32)
    for session_classes in incremental_sessions
]

KeyboardInterrupt: 

In [117]:
optimizer = torch.optim.Adam(sppr.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

# 기존 프로토타입 초기화 (임의로 초기화 가능)
old_prototypes = torch.zeros(len(incremental_sessions[0]), feature_dim).to(device)

# 세션별 훈련 루프
for images, labels in session_loader:
    images, labels = images.to(device), labels.to(device)

    with torch.no_grad():
        new_embeddings = model.encode_image(images)
        new_embeddings /= new_embeddings.norm(dim=-1, keepdim=True)  # 정규화

    # 데이터 타입을 float32로 변환
    new_embeddings = new_embeddings.to(torch.float32)
    old_prototypes = old_prototypes.to(torch.float32)

    # 프로토타입 업데이트
    refined_prototypes = sppr(new_embeddings, old_prototypes)

    # 로짓 계산
    logits = torch.matmul(new_embeddings, refined_prototypes.T)
    loss = loss_fn(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 세션별 프로토타입 갱신
    old_prototypes = refined_prototypes.detach()

RuntimeError: Given groups=1, weight of size [512, 512, 1], expected input[1, 32, 512] to have 512 channels, but got 32 channels instead