In [2]:
import os 
from PIL import Image 
from arguments import parser 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from sklearn.manifold import TSNE
from datasets import create_dataset 
from torch.utils.data import DataLoader
from utils.metrics import MetricCalculator, loco_auroc
from accelerate import Accelerator
from omegaconf import OmegaConf
import seaborn as sns 
from main import torch_seed



# Pin the visible GPU before seeding or model construction can touch CUDA
# (CUDA is initialised lazily, so setting this after `import torch` still applies
# as long as nothing has created a CUDA context yet).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch_seed(42)

default_setting = './configs/default/mvtecad.yaml'
model_setting = './configs/model/cfgcad.yaml'
cfg = parser(True, default_setting, model_setting)

device = 'cuda'
# Instantiate the model class named by cfg.MODEL.method from the project's `models` package.
model = __import__('models').__dict__[cfg.MODEL.method](
    backbone=cfg.MODEL.backbone,
    **cfg.MODEL.params
).to(device)

# Build one train/test DataLoader pair per class name; later cells also rely on the
# globals `trainloader` / `testloader`, which end up bound to the LAST class's loaders.
loader_dict = {}
accelerator = Accelerator()
for cn in cfg.DATASET.class_names:
    trainset, testset = create_dataset(
        dataset_name=cfg.DATASET.dataset_name,
        datadir=cfg.DATASET.datadir,
        class_name=cn,
        img_size=cfg.DATASET.img_size,
        mean=cfg.DATASET.mean,
        std=cfg.DATASET.std,
        aug_info=cfg.DATASET.aug_info,
        **cfg.DATASET.get('params', {})
    )
    trainloader = DataLoader(
        dataset=trainset,
        batch_size=cfg.DATASET.batch_size,
        num_workers=cfg.DATASET.num_workers,
        shuffle=True,
    )
    testloader = DataLoader(
        dataset=testset,
        batch_size=8,
        num_workers=cfg.DATASET.num_workers,
        shuffle=False,
    )
    loader_dict[cn] = {'train': trainloader, 'test': testloader}


 Experiment Name : .-Continual_True-online_False



In [3]:
from CL import CL_Transformer
# Sparsity settings for the continual-learning wrapper, read from the config.
sparsity_config = cfg.CONTINUAL.method.params

# Wrap the model in the project's CL_Transformer; replace_percentage=0.2 is passed
# straight through — its exact meaning is defined in the CL module (not visible here).
cl_manager = CL_Transformer(model, accelerator.device, sparsity_config, replace_percentage=0.2)


# 마스킹된 파라미터 분석
def analyze_mask_distribution(cl_manager):
    total_params = 0
    masked_params = 0
    layer_stats = {}
    
    print("=== CL_Transformer 마스킹 분석 ===")
    
    for name, param in cl_manager.model.named_parameters():
        if name in cl_manager.mask:
            # 현재 레이어의 마스크 정보 가져오기
            mask = cl_manager.mask[name]
            total = param.numel()
            masked = (mask == 0).sum().item()  # 0인 값(마스킹된 값)의 개수
            active = (mask == 1).sum().item()  # 1인 값(활성화된 값)의 개수
            
            # 레이어 이름에서 가장 상위 모듈명 추출
            module_name = name.split('.')[0]
            if module_name not in layer_stats:
                layer_stats[module_name] = {'total': 0, 'masked': 0, 'active': 0}
            
            # 통계 업데이트
            layer_stats[module_name]['total'] += total
            layer_stats[module_name]['masked'] += masked
            layer_stats[module_name]['active'] += active
            
            total_params += total
            masked_params += masked
            
            # 개별 레이어 정보 출력
            print(f"레이어: {name}")
            print(f"  전체 파라미터: {total:,}")
            print(f"  마스킹된 파라미터: {masked:,} ({masked/total*100:.2f}%)")
            print(f"  활성화된 파라미터: {active:,} ({active/total*100:.2f}%)")
            
    # 모듈별 통계 출력
    print("\n=== 모듈별 마스킹 통계 ===")
    for module_name, stats in layer_stats.items():
        total = stats['total']
        masked = stats['masked']
        active = stats['active']
        print(f"모듈: {module_name}")
        print(f"  전체 파라미터: {total:,}")
        print(f"  마스킹된 파라미터: {masked:,} ({masked/total*100:.2f}%)")
        print(f"  활성화된 파라미터: {active:,} ({active/total*100:.2f}%)")
    
    # 전체 통계 출력
    print("\n=== 전체 마스킹 통계 ===")
    print(f"전체 파라미터: {total_params:,}")
    print(f"마스킹된 파라미터: {masked_params:,} ({masked_params/total_params*100:.2f}%)")
    print(f"활성화된 파라미터: {total_params-masked_params:,} ({(total_params-masked_params)/total_params*100:.2f}%)")
    
    # sparsity_config 설정 출력
    print("\n=== Sparsity 설정 ===")
    print(f"기본 스파시티: {cl_manager.sparsity_config.get('default_sparsity', 'N/A')}")
    print("레이어별 스파시티 설정:")
    for layer, sparsity in cl_manager.sparsity_config.get('layer_specific', {}).items():
        print(f"  {layer}: {sparsity}")
    
    return {
        'total_params': total_params,
        'masked_params': masked_params,
        'layer_stats': layer_stats
    }

# Run the mask analysis; the returned summary dict is kept as `mask_stats`.
mask_stats = analyze_mask_distribution(cl_manager)

Identifying learnable layers (focusing on nn.Linear/Conv2d, ignoring nn.LayerNorm)...
Identified Learnable Layers: ['backbone._conv_stem', 'backbone._blocks.0._depthwise_conv', 'backbone._blocks.0._se_reduce', 'backbone._blocks.0._se_expand', 'backbone._blocks.0._project_conv', 'backbone._blocks.1._depthwise_conv', 'backbone._blocks.1._se_reduce', 'backbone._blocks.1._se_expand', 'backbone._blocks.1._project_conv', 'backbone._blocks.2._expand_conv', 'backbone._blocks.2._depthwise_conv', 'backbone._blocks.2._se_reduce', 'backbone._blocks.2._se_expand', 'backbone._blocks.2._project_conv', 'backbone._blocks.3._expand_conv', 'backbone._blocks.3._depthwise_conv', 'backbone._blocks.3._se_reduce', 'backbone._blocks.3._se_expand', 'backbone._blocks.3._project_conv', 'backbone._blocks.4._expand_conv', 'backbone._blocks.4._depthwise_conv', 'backbone._blocks.4._se_reduce', 'backbone._blocks.4._se_expand', 'backbone._blocks.4._project_conv', 'backbone._blocks.5._expand_conv', 'backbone._blocks.5._

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math

# --- CosineLinear adapted from inc_net.py (simplified variant) ---
class CosineLinear(nn.Module):
    """Linear classifier head with two modes.

    * ``use_RP=False``: cosine-similarity logits between L2-normalised inputs and
      L2-normalised weight rows, scaled by the learnable ``sigma``.
    * ``use_RP=True``: if ``W_rand`` has been assigned, inputs are first projected
      with ``relu(x @ W_rand)``; logits are then a plain (un-normalised) linear map,
      scaled by ``sigma``.

    ``forward`` returns a dict ``{'logits': tensor}``.
    """

    def __init__(self, in_features, out_features, M_target_dim=None, use_RP=False, device='cuda'):
        super().__init__()
        self.in_features_original = in_features
        self.out_features = out_features
        self.device = device
        self.use_RP = use_RP
        self.W_rand = None  # assigned externally when random projection is active

        # With an active projection the weight acts on the M-dim projected space.
        projects = use_RP and M_target_dim is not None and M_target_dim > 0
        weight_cols = M_target_dim if projects else in_features
        self.weight = nn.Parameter(torch.empty(out_features, weight_cols).to(self.device))
        self.sigma = nn.Parameter(torch.empty(1).to(self.device))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) weights and sigma = 1."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.sigma is not None:
            self.sigma.data.fill_(1)

    def forward(self, input_features_L):
        if not self.use_RP:
            unit_inputs = F.normalize(input_features_L, p=2, dim=1)
            unit_weight = F.normalize(self.weight, p=2, dim=1)
            logits = F.linear(unit_inputs, unit_weight)
        else:
            if self.W_rand is None:
                features = input_features_L
            else:
                features = F.relu(input_features_L @ self.W_rand)
            logits = F.linear(features, self.weight)

        if self.sigma is not None:
            logits = self.sigma * logits
        return {'logits': logits}

class PatchRanPACAnomalyDetector:
    """RanPAC-style patch classifier used for image anomaly scoring.

    Every patch of an image inherits that image's class label. A closed-form ridge
    regression head is fitted from accumulated statistics G = H^T H and Q = H^T Y,
    where H are patch features (optionally passed through a fixed random ReLU
    projection ``relu(x @ W_rand)``) and Y are one-hot labels. Because G and Q are
    accumulated across calls, new classes can be added one training call at a time.

    At inference, each patch's anomaly score is ``1 - max softmax probability`` and
    scores are aggregated per image by max or mean.
    """

    def __init__(self, num_classes, original_feature_dim_L,
                 M_target_dim=None, use_RP_flag=False, device='cuda'):
        self.num_classes = num_classes  # number of head outputs (distinct labels the head can learn)
        self.original_feature_dim_L = original_feature_dim_L
        self.M_target_dim = M_target_dim
        self.use_RP_flag = use_RP_flag
        self.device = device
        self.label_idx_mapping = {}  # raw label value -> head output index

        use_projection = self.use_RP_flag and self.M_target_dim is not None and self.M_target_dim > 0

        self.W_rand = None
        self.fc_layer = CosineLinear(
            in_features=self.original_feature_dim_L,
            out_features=self.num_classes,
            M_target_dim=self.M_target_dim,
            use_RP=self.use_RP_flag,
            device=self.device
        ).to(self.device)

        if use_projection:
            # Fixed (never trained) Gaussian projection L -> M shared with the head.
            self.W_rand = torch.randn(
                self.original_feature_dim_L,
                self.M_target_dim
            ).to(self.device)
        self.fc_layer.W_rand = self.W_rand

        # G and Q live in the projected space when a projection is used.
        gq_dim = self.M_target_dim if use_projection else self.original_feature_dim_L
        self.G = torch.zeros(gq_dim, gq_dim).to(self.device)
        self.Q = torch.zeros(gq_dim, self.num_classes).to(self.device)
        # The readout Wo is solved in train_ridge_regression_head and written into fc_layer.weight.

    def _calculate_entropy(self, probabilities):
        """Shannon entropy over the last dim (1e-9 added to avoid log(0))."""
        probabilities = probabilities + 1e-9
        entropy = -torch.sum(probabilities * torch.log(probabilities), dim=-1)
        return entropy

    def _calculate_max_softmax_prob(self, probabilities):
        """Maximum class probability over the last dim."""
        max_prob, _ = torch.max(probabilities, dim=-1)
        return max_prob

    @staticmethod
    def _label_key(label):
        """Dict key for a label: tensors are converted with .item(), other values pass through."""
        return label.item() if isinstance(label, torch.Tensor) else label

    def update_label_idx_mapping(self, list_of_image_labels):
        """Assign a head output index to every label value not seen before.

        Args:
            list_of_image_labels: 1-D tensor of labels, or a sequence of scalars / 0-d tensors.

        Raises:
            ValueError: if a new label would need an index >= ``num_classes``.
        """
        if not hasattr(self, 'label_idx_mapping'):
            self.label_idx_mapping = {}

        labels = list_of_image_labels
        if not isinstance(labels, torch.Tensor):
            labels = torch.as_tensor([self._label_key(l) for l in labels])

        for label in torch.unique(labels):
            # Look up by the Python value: Tensor hashing is identity-based, so testing
            # the tensor itself against the int keys would never match.
            key = label.item()
            if key in self.label_idx_mapping:
                continue
            next_idx = len(self.label_idx_mapping)
            if next_idx >= self.num_classes:
                raise ValueError(
                    f"Label {key} needs index {next_idx}, but the head only has "
                    f"{self.num_classes} outputs"
                )
            self.label_idx_mapping[key] = next_idx
            print(f"Added new label {label} with index {next_idx}")

        print(f"Current label mapping: {self.label_idx_mapping}")

    def train_ridge_regression_head(self, list_of_image_patch_features_L, list_of_image_labels):
        """Accumulate G, Q from a set of images and re-solve the readout Wo.

        Args:
            list_of_image_patch_features_L: iterable whose i-th element is one image's
                (num_patches, L_dim) patch-feature tensor (a stacked
                (N, num_patches, L_dim) tensor also works).
            list_of_image_labels: i-th element is that image's class label (tensor of
                labels, or a sequence of scalars / 0-d tensors).

        Returns:
            The new Wo (num_classes x feature_dim), or None if no images were given.
        """
        print(f"\n--- Training head with new batch of images ---")
        self.update_label_idx_mapping(list_of_image_labels)

        all_patches_H_for_task = []
        all_patch_onehot_labels_for_task = []
        for img_idx, image_patches_L in enumerate(list_of_image_patch_features_L):
            image_label = self.label_idx_mapping[self._label_key(list_of_image_labels[img_idx])]
            image_patches_L = image_patches_L.to(self.device)  # (num_patches, L_dim)

            # Apply the random projection to all patches of this image (when enabled).
            if self.use_RP_flag and self.W_rand is not None:
                image_patches_H = F.relu(image_patches_L @ self.W_rand)  # (num_patches, M_dim)
            else:
                image_patches_H = image_patches_L  # (num_patches, L_dim)
            all_patches_H_for_task.append(image_patches_H)

            # Every patch inherits the image's label.
            num_patches = image_patches_H.shape[0]
            patch_labels = torch.full((num_patches,), image_label, dtype=torch.long, device=self.device)
            all_patch_onehot_labels_for_task.append(
                F.one_hot(patch_labels, num_classes=self.num_classes).float()
            )

        if not all_patches_H_for_task:
            print("No patch features to process for training.")
            return None

        # Stack all patches (and labels) of this call into single matrices.
        current_batch_all_patches_H = torch.cat(all_patches_H_for_task, dim=0)
        current_batch_all_patch_onehot_labels = torch.cat(all_patch_onehot_labels_for_task, dim=0)

        # Update the running statistics G, Q.
        self.G += current_batch_all_patches_H.T @ current_batch_all_patches_H
        self.Q += current_batch_all_patches_H.T @ current_batch_all_patch_onehot_labels
        print(f"Updated G shape: {self.G.shape}, Updated Q shape: {self.Q.shape}")

        # Lambda is tuned on this call's patches only (simplified from RanPAC's
        # optimise_ridge_parameter) and then applied to the accumulated G, Q.
        current_task_lambda = self._optimise_ridge_for_current_batch(
            current_batch_all_patches_H, current_batch_all_patch_onehot_labels)

        G_reg = self.G + current_task_lambda * torch.eye(self.G.size(0), device=self.device)
        try:
            Wo_final = torch.linalg.solve(G_reg, self.Q).T
        except RuntimeError as e:  # torch.linalg.LinAlgError (singular system) subclasses RuntimeError
            print(f"Error solving for Wo: {e}. Using pseudo-inverse.")
            Wo_final = torch.linalg.lstsq(G_reg, self.Q).solution.T

        self.fc_layer.weight.data = Wo_final
        print(f"Updated fc_layer weight with Wo, shape: {self.fc_layer.weight.data.shape}")
        return Wo_final

    def _optimise_ridge_for_current_batch(self, H_features_batch, Y_onehot_batch, ridges=None):
        """Pick the ridge lambda with the lowest validation MSE on a random 80/20 split.

        Returns 1.0 when there are fewer than 20 samples. Lambdas whose system cannot
        be solved are skipped.
        """
        if ridges is None:
            ridges = 10.0 ** np.arange(-3, 4)
        num_samples = H_features_batch.shape[0]
        if num_samples < 20:
            return 1.0  # too few samples for a meaningful split

        perm = torch.randperm(num_samples, device=self.device)
        H_shuffled, Y_shuffled = H_features_batch[perm], Y_onehot_batch[perm]
        split_idx = int(num_samples * 0.8)
        H_train, H_val = H_shuffled[:split_idx], H_shuffled[split_idx:]
        Y_train, Y_val = Y_shuffled[:split_idx], Y_shuffled[split_idx:]
        if H_train.shape[0] == 0 or H_val.shape[0] == 0:
            return 1.0

        Q_train_batch = H_train.T @ Y_train
        G_train_batch = H_train.T @ H_train
        eye = torch.eye(G_train_batch.size(0), device=self.device)
        best_lambda, min_mse = ridges[0], float('inf')

        for ridge_lambda in ridges:
            try:
                Wo_temp = torch.linalg.solve(G_train_batch + ridge_lambda * eye, Q_train_batch)
            except RuntimeError:
                continue  # singular for this lambda: try the next one
            mse = F.mse_loss(H_val @ Wo_temp, Y_val).item()
            if mse < min_mse:
                min_mse, best_lambda = mse, ridge_lambda
        print(f"Optimal lambda for current batch: {best_lambda} (MSE: {min_mse:.4f})")
        return best_lambda

    def predict_image_anomaly(self, image_patches_L, alpha=0.5, beta=0.5, aggregation_method='max'):
        """Score images from their L-dim patch features.

        Args:
            image_patches_L: (num_patches, L_dim) for one image or
                (batch, num_patches, L_dim) for a batch.
            alpha, beta: currently unused (the combined confidence/entropy score is
                disabled); kept for signature compatibility.
            aggregation_method: 'max' or 'mean' aggregation of patch scores per image.

        Returns:
            dict with image-level score/prediction (batch,) and patch-level predicted
            classes, anomaly scores, entropy and confidence (batch, num_patches).

        Raises:
            ValueError: for an unknown ``aggregation_method``.
        """
        self.fc_layer.eval()
        if image_patches_L.ndim == 2:  # single image: add a batch dim
            image_patches_L = image_patches_L.unsqueeze(0)

        batch_size, num_patches, _ = image_patches_L.shape
        # Flatten to (batch * num_patches, L_dim) so every patch goes through the head at once.
        flat_patches_L = image_patches_L.reshape(-1, self.original_feature_dim_L).to(self.device)

        with torch.no_grad():
            # 1. Logits (the head applies the random projection itself when enabled).
            patch_logits = self.fc_layer(flat_patches_L)['logits']  # (B*num_patches, num_classes)

            # 2. Softmax probabilities and derived quantities.
            patch_probabilities = F.softmax(patch_logits, dim=-1)
            patch_predicted_classes = torch.argmax(patch_probabilities, dim=-1)
            patch_entropy = self._calculate_entropy(patch_probabilities)
            patch_confidence = self._calculate_max_softmax_prob(patch_probabilities)

            # 3. Per-patch anomaly score: low confidence => more anomalous.
            #    (alternative, disabled: alpha * (1 - confidence) + beta * entropy)
            patch_anomaly_scores_combined = 1 - patch_confidence

            # 4. Back to (batch, num_patches) and aggregate per image.
            patch_anomaly_scores_combined_reshaped = patch_anomaly_scores_combined.reshape(batch_size, num_patches)
            patch_predicted_classes_reshaped = patch_predicted_classes.reshape(batch_size, num_patches)
            patch_entropy_reshaped = patch_entropy.reshape(batch_size, num_patches)
            patch_confidence_reshaped = patch_confidence.reshape(batch_size, num_patches)

            if aggregation_method == 'max':
                image_level_anomaly_score = torch.max(patch_anomaly_scores_combined_reshaped, dim=1)[0]
                # Image prediction = predicted class of the most anomalous patch.
                indices_max_anomaly_patch = torch.argmax(patch_anomaly_scores_combined_reshaped, dim=1)
                image_level_predicted_class = patch_predicted_classes_reshaped[torch.arange(batch_size), indices_max_anomaly_patch]
            elif aggregation_method == 'mean':
                image_level_anomaly_score = torch.mean(patch_anomaly_scores_combined_reshaped, dim=1)
                # Image prediction = most frequent patch prediction.
                image_level_predicted_class = torch.mode(patch_predicted_classes_reshaped, dim=1)[0]
            else:
                raise ValueError("Unsupported aggregation_method.")

        # Patch-level details are returned too, e.g. for anomaly localisation.
        return {
            "image_level_anomaly_score": image_level_anomaly_score,  # (batch_size,)
            "image_level_predicted_class": image_level_predicted_class,  # (batch_size,)
            "patch_level_predicted_classes": patch_predicted_classes_reshaped,  # (batch_size, num_patches)
            "patch_level_anomaly_scores": patch_anomaly_scores_combined_reshaped,  # (batch_size, num_patches)
            "patch_level_entropy": patch_entropy_reshaped,  # (batch_size, num_patches)
            "patch_level_confidence": patch_confidence_reshaped  # (batch_size, num_patches)
        }
        
import scipy.ndimage as ndimage

class RescaleSegmentor:
    """Turn low-resolution patch score grids into smoothed image-resolution maps.

    Args:
        device: device used for the bilinear upsampling.
        target_size: output side length (int) or (H, W), forwarded to F.interpolate.
        smoothing: Gaussian sigma applied to each upsampled map (default 4).
    """

    def __init__(self, device, target_size=224, smoothing=4):
        self.device = device
        self.target_size = target_size
        self.smoothing = smoothing

    def convert_to_segmentation(self, patch_scores):
        """Upsample and smooth a batch of score grids.

        Args:
            patch_scores: (batch, h, w) numpy array or torch tensor.

        Returns:
            list with one smoothed numpy map per batch element.
        """
        with torch.no_grad():
            if isinstance(patch_scores, np.ndarray):
                patch_scores = torch.from_numpy(patch_scores)
            # F.interpolate expects a channel dim: (batch, 1, h, w).
            scores = patch_scores.to(self.device).unsqueeze(1)
            scores = F.interpolate(
                scores, size=self.target_size, mode="bilinear", align_corners=False
            )
            upsampled = scores.squeeze(1).cpu().numpy()

        return [
            ndimage.gaussian_filter(score_map, sigma=self.smoothing)
            for score_map in upsampled
        ]
        


In [None]:
from utils.metrics import MetricCalculator, loco_auroc
model.eval()
# Metric accumulators: image-level, pixel-level and class-prediction level.
img_level = MetricCalculator(metric_list=['auroc', 'average_precision'])
pix_level = MetricCalculator(metric_list=['auroc', 'average_precision'])
cls_level = MetricCalculator(metric_list=['auroc', 'average_precision'])
segmentor = RescaleSegmentor(device='cuda')

# Detector configuration
ORIGINAL_FEATURE_DIM_L = 768  # per-patch feature width from model.embed_img — TODO confirm against the backbone
PROJECTION_DIM_M = 12000      # random-projection width M (M > L); G becomes an M x M matrix
USE_RP = True
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Count distinct class labels over all training loaders.
# NOTE: this iterates every training image once just to read the labels.
unique_classes = set()
for class_name, loaders in loader_dict.items():
    trainloader = loaders['train']
    for _, _, cls in trainloader:
        unique_classes.update(cls.numpy())
NUM_CLASSES_TOTAL = len(unique_classes)
print(f"Detected {NUM_CLASSES_TOTAL} unique classes in the dataset")

# 1. Build the detector
detector = PatchRanPACAnomalyDetector(
    num_classes=NUM_CLASSES_TOTAL,
    original_feature_dim_L=ORIGINAL_FEATURE_DIM_L,
    M_target_dim=PROJECTION_DIM_M,
    use_RP_flag=USE_RP,
    device=DEVICE
)

with torch.no_grad():
    # 2. Continual training: one ridge-regression update per class, in loader_dict order.
    for class_name, loaders in loader_dict.items():
        trainloader = loaders['train']

        outputs_list = []
        cls_list = []
        for imgs, labels, cls in trainloader:
            outputs = model.embed_img(imgs.to('cuda'))
            outputs_list.append(outputs.detach().cpu())
            cls_list.append(cls.detach().cpu())

        cls_list = torch.concat(cls_list)
        outputs_list = torch.concat(outputs_list)

        detector.train_ridge_regression_head(outputs_list, cls_list)
        print(f"Trained detector on class: {class_name}")

    # 3. Quick check on the test split of the last class trained above
    #    (the full per-class evaluation is done in the next cell).
    quick_testloader = loader_dict[class_name]['test']
    for img, label, class_label, gts in quick_testloader:
        patches_L = model.embed_img(img.to('cuda'))
        results = detector.predict_image_anomaly(patches_L, aggregation_method='max')

        score = results['image_level_anomaly_score']
        patch_scores = results['patch_level_anomaly_scores']  # (batch, num_patches)
        # Restore the square patch grid before upsampling to image resolution.
        batch_size, num_patches = patch_scores.shape
        side = int(round(num_patches ** 0.5))
        assert side * side == num_patches, f"expected a square patch grid, got {num_patches} patches"
        score_map = np.array(segmentor.convert_to_segmentation(patch_scores.reshape(batch_size, side, side)))

        pix_level.update(score_map, gts.type(torch.int))
        img_level.update(score, label.type(torch.int))
        cls_level.update(
            results['image_level_predicted_class'],
            pd.Series(class_label.numpy()).map(detector.label_idx_mapping).values,
        )

In [None]:
from sklearn.metrics import classification_report

# Evaluate the continually trained detector on every class's test split.
# Fresh accumulators make this cell self-contained: re-running it, or running it
# after another cell has already fed the calculators, does not double-count samples.
img_level = MetricCalculator(metric_list=['auroc', 'average_precision'])
pix_level = MetricCalculator(metric_list=['auroc', 'average_precision'])
cls_level = MetricCalculator(metric_list=['auroc', 'average_precision'])

with torch.no_grad():  # inference only: no autograd graph through the backbone
    for class_name, loaders in loader_dict.items():
        testloader = loaders['test']
        print(f"Testing on class: {class_name}")

        for img, label, class_label, gts in testloader:
            patches_L = model.embed_img(img.to('cuda'))
            results = detector.predict_image_anomaly(patches_L, aggregation_method='max')

            score = results['image_level_anomaly_score']
            patch_scores = results['patch_level_anomaly_scores']  # (batch, num_patches)
            # Restore the square patch grid before upsampling to image resolution.
            batch_size, num_patches = patch_scores.shape
            side = int(round(num_patches ** 0.5))
            assert side * side == num_patches, f"expected a square patch grid, got {num_patches} patches"
            score_map = np.array(segmentor.convert_to_segmentation(patch_scores.reshape(batch_size, side, side)))

            pix_level.update(score_map, gts.type(torch.int))
            img_level.update(score, label.type(torch.int))
            cls_level.update(
                results['image_level_predicted_class'],
                pd.Series(class_label.numpy()).map(detector.label_idx_mapping).values,
            )

# Compute the accumulated image- and pixel-level metrics.
i_results, p_results = img_level.compute(), pix_level.compute()

# Report image-level and pixel-level metrics
print(f"Image-level metrics: AUROC: {i_results['auroc']:.3f}, AP: {i_results['average_precision']:.3f}")
print(f"Pixel-level metrics: AUROC: {p_results['auroc']:.3f}, AP: {p_results['average_precision']:.3f}")

# Class-prediction quality via sklearn's classification_report
preds = np.concatenate(cls_level.preds)
targets = np.concatenate(cls_level.targets)
print("\nClass-level classification report:")
print(classification_report(targets, preds))


In [None]:
import os 
from PIL import Image 
from arguments import parser 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
import numpy as np 
import pandas as pd 
import matplotlib.pyplot as plt 
from sklearn.manifold import TSNE
from datasets import create_dataset 
from torch.utils.data import DataLoader
from utils.metrics import MetricCalculator, loco_auroc
from accelerate import Accelerator
from omegaconf import OmegaConf
import seaborn as sns 
from main import torch_seed


# Pin the visible GPU before seeding or model construction can touch CUDA
# (CUDA is initialised lazily, so setting this after `import torch` still applies
# as long as nothing has created a CUDA context yet).
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
torch_seed(42)

default_setting = './configs/default/mvtecad_15.yaml'
model_setting = './configs/model/simplenet.yaml'
cfg = parser(True, default_setting, model_setting)

# Instantiate the model class named by cfg.MODEL.method from the project's `models` package.
model = __import__('models').__dict__[cfg.MODEL.method](
    backbone=cfg.MODEL.backbone,
    **cfg.MODEL.params
).to('cuda')

# Build one train/test DataLoader pair per class name; the globals `trainloader` /
# `testloader` end up bound to the LAST class's loaders and are used by later cells.
loader_dict = {}
accelerator = Accelerator()
for cn in cfg.DATASET.class_names:
    trainset, testset = create_dataset(
        dataset_name=cfg.DATASET.dataset_name,
        datadir=cfg.DATASET.datadir,
        class_name=cn,
        img_size=cfg.DATASET.img_size,
        mean=cfg.DATASET.mean,
        std=cfg.DATASET.std,
        aug_info=cfg.DATASET.aug_info,
        **cfg.DATASET.get('params', {})
    )
    trainloader = DataLoader(
        dataset=trainset,
        batch_size=cfg.DATASET.batch_size,
        num_workers=cfg.DATASET.num_workers,
        shuffle=True,
    )
    testloader = DataLoader(
        dataset=testset,
        batch_size=8,
        num_workers=cfg.DATASET.num_workers,
        shuffle=False,
    )
    loader_dict[cn] = {'train': trainloader, 'test': testloader}


 Experiment Name : test-all-continual-scheduler-Continual_True-online_False



In [2]:
# Alias the model as `self` — presumably so model method bodies can be run cell by cell below.
self = model 

# Take the first batch of the global `testloader` (bound in the setup cell's loop to
# the last class in cfg.DATASET.class_names) and move the images to the GPU.
for images, label, class_label, gts in testloader:
    images = images.to('cuda')
    break 

In [None]:
# Embed the inspection batch grabbed in the previous cell (`images`, not the stale
# `img` left over from an evaluation loop) and apply the model's pre-projection.
# NOTE(review): only the first element of `_embed`'s return value is used — confirm what it holds.
feat = self._embed(images, evaluation=False)[0]
true_feats = self.pre_projection(feat)

In [31]:
# Run the feature aggregator on the batch and keep only the configured layers
# (indexed by name/key from self.layers_to_extract_from).
features = self.forward_modules["feature_aggregator"](images)
features = [features[layer] for layer in self.layers_to_extract_from]
# Print each kept layer's feature shape.
for f in features:
    print(f.shape)

torch.Size([8, 512, 28, 28])
torch.Size([8, 1024, 14, 14])


In [32]:
# Convert any token-sequence features (B, L, C) into spatial maps (B, C, side, side).
# The side length is computed with `** 0.5` because `math` is not imported in this session.
for i, feat in enumerate(features):
    if len(feat.shape) == 3:
        B, L, C = feat.shape
        side = int(round(L ** 0.5))
        assert side * side == L, f"expected a square token grid, got {L} tokens"
        features[i] = feat.reshape(B, side, side, C).permute(0, 3, 1, 2)

for f in features:
    print(f.shape)

torch.Size([8, 512, 28, 28])
torch.Size([8, 1024, 14, 14])


In [None]:
# Split every feature map into local patches via the model's patch_maker; with
# return_spatial_info=True each result is indexable and element 0 is kept as the patches.
features = [
            self.patch_maker.patchify(x, return_spatial_info=True) for x in features
        ]
features = [x[0] for x in features]
# Print each layer's patch tensor shape.
for f in features:
    print(f.shape)

torch.Size([8, 784, 512, 3, 3])
torch.Size([8, 196, 1024, 3, 3])


In [61]:
import torch
import torch.nn.functional as F
import einops
from timm import create_model

class VitBackbone(torch.nn.Module):
    """Frozen timm ViT returning intermediate transformer-block outputs as spatial maps.

    No class/prefix token is prepended: patch tokens receive their own positional
    embeddings and go straight through the blocks. Each returned map has shape
    (B, C, P, P), where P*P is the number of patch tokens (a square grid is required).
    """

    def __init__(self, model_name='vit_base_patch16_224.orig_in21k', device='cuda'):
        super(VitBackbone, self).__init__()
        self.model = create_model(model_name, pretrained=True).to(device)

        # Freeze every parameter: the backbone is used purely as a feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False

    def _patch_pos_embed(self):
        """Positional embedding for patch tokens only, or None if the model has none."""
        pos_embed = getattr(self.model, 'pos_embed', None)
        if pos_embed is None:
            return None
        # timm keeps prefix-token (class/register) positions at the front of pos_embed
        # unless the model was built with no_embed_class=True.
        if getattr(self.model, 'no_embed_class', False):
            return pos_embed
        num_prefix = getattr(self.model, 'num_prefix_tokens', 1)
        return pos_embed[:, num_prefix:, :]

    @staticmethod
    def _to_spatial(tokens):
        """(B, P*P, C) token sequence -> (B, C, P, P) map (row-major token order)."""
        B, num_tokens, C = tokens.shape
        side = int(num_tokens ** 0.5)
        return tokens.transpose(1, 2).reshape(B, C, side, side)

    def forward(self, x, layers_to_extract=None):
        """
        Args:
            x: input batch for model.patch_embed — presumably (B, 3, H, W) images; TODO confirm.
            layers_to_extract: indices of transformer blocks whose outputs are returned;
                defaults to all blocks.

        Returns:
            list of (B, C, P, P) tensors, in block order.
        """
        if layers_to_extract is None:
            layers_to_extract = range(len(self.model.blocks))
        wanted = set(layers_to_extract)
        last_wanted = max(wanted) if wanted else -1

        x = self.model.patch_embed(x)
        pos_embed = self._patch_pos_embed()
        if pos_embed is not None:
            x = x + pos_embed
        x = self.model.pos_drop(x)

        features = []
        for i, block in enumerate(self.model.blocks):
            if i > last_wanted:
                break  # no requested layer beyond this point: skip the remaining blocks
            x = block(x)
            if i in wanted:
                features.append(x)

        return [self._to_spatial(f) for f in features]


# Create the frozen ViT backbone and extract outputs of transformer blocks 3, 6 and 9
# for the inspection batch; the result is a list of spatial feature maps in block order.
backbone = VitBackbone().to('cuda')
features = backbone(images, layers_to_extract=[3, 6, 9])