The dataset will be CelebDFV2 first; GTA-V data can be merged in later.

Where should the implementation start...

Start by loading SigLIP. Quantize it to keep the memory footprint as small as possible.

In [1]:
!uv pip install --system lightning
!pip install -q lightning

Using Python 3.12.12 environment at: /usr
Audited 1 package in 102ms


In [2]:
!uv pip install --system bitsandbytes
!pip install -q bitsandbytes

Using Python 3.12.12 environment at: /usr
Audited 1 package in 104ms


In [3]:
import lightning.pytorch as L
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel, BitsAndBytesConfig

embed_size = 768

frame_token = 576
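
The value `frame_token = 576` follows from the encoder's patch grid. The sketch below is a quick check, assuming the 384x384 input and 16x16 patches implied by the `google/siglip2-base-patch16-384` checkpoint loaded later in this notebook. The helper name `tokens_per_frame` is illustrative and not part of the notebook.

```python
def tokens_per_frame(image_size: int, patch_size: int) -> int:
    """Number of patch tokens a ViT-style encoder emits for one square frame."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    side = image_size // patch_size
    return side * side

frame_token = tokens_per_frame(384, 16)   # 24 * 24 patches
num_frames = 32                           # clip length used by the dataset
# Sequence length seen by the temporal transformer: all patch tokens plus one class token
total_tokens = num_frames * frame_token + 1
print(frame_token, total_tokens)
```

With 32 frames the temporal transformer attends over more than eighteen thousand tokens, which is the main reason memory is a concern in this notebook.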

In [4]:
torch.cuda.empty_cache()
import gc
gc.collect()

90

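Honestly, I would rather follow the original model wherever possible, but for now the encoder is swapped for a lighter one.

If performance falls short later, the encoder is one more thing to blame.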

In [5]:
import torch.nn as nn

In [6]:
import torch.nn.functional as F

class ViTEncoder(nn.Module):
    def __init__(self, embed_size=768, num_heads=12, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        assert self.head_dim * num_heads == embed_size, "embed_size must be divisible by num_heads"

        self.q_proj = nn.Linear(embed_size, embed_size)
        self.k_proj = nn.Linear(embed_size, embed_size)
        self.v_proj = nn.Linear(embed_size, embed_size)
        self.out_proj = nn.Linear(embed_size, embed_size)
        self.dropout_p = dropout

        self.ln1 = nn.LayerNorm(embed_size)
        self.mlp = nn.Sequential(
            nn.Linear(embed_size, 4 * embed_size),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * embed_size, embed_size),
            nn.Dropout(dropout),
        )
        self.ln2 = nn.LayerNorm(embed_size)

    def forward(self, x, return_head_contrib=False):
        # x: (batch_size, seq_len, embed_size)
        batch_size, seq_len, embed_size = x.shape

        # Project queries, keys, values
        q = self.q_proj(x) # (batch_size, seq_len, embed_size)
        k = self.k_proj(x) # (batch_size, seq_len, embed_size)
        v = self.v_proj(x) # (batch_size, seq_len, embed_size)

        # Split into multiple heads
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (b, num_heads, seq_len, head_dim)
        k = k.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (b, num_heads, seq_len, head_dim)
        v = v.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (b, num_heads, seq_len, head_dim)

        # Apply scaled dot product attention
        # dropout_p is applied only during training
        attn_output_raw = F.scaled_dot_product_attention(
            q, k, v,
            attn_mask=None,
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False
        )
        # attn_output: (b, num_heads, seq_len, head_dim)


        head_contrib = None
        if return_head_contrib:
            # out_proj_weight: [embed_size, embed_size] -> [num_heads, head_dim, embed_size]
            out_proj_weight = self.out_proj.weight.t().view(self.num_heads, self.head_dim, embed_size)
            # b: batch size, h: num_head, l: seq_len, k: head_dim, d: embed_dim
            head_contrib = torch.einsum("bhlk, hkd -> bhld", attn_output_raw, out_proj_weight)
            # -> head_contrib: [batch, num_heads, seq_len, embed_dim] (out_proj bias not included)
            # The caller still has to remove the class token and pool over the frame dimension

        # Concatenate heads and apply final linear projection
        attn_output = attn_output_raw.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_size)
        attn_output = self.out_proj(attn_output) # (batch_size, seq_len, embed_size)


        x = self.ln1(x + attn_output) # Residual connection + LayerNorm
        x = self.ln2(x + self.mlp(x))

        if return_head_contrib:
            return x, head_contrib
        return x
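
The `return_head_contrib` branch relies on the fact that projecting the concatenated heads equals the sum of each head projected through its own slice of the output weight. The toy check below uses plain Python lists instead of tensors; the numbers and the `matmul` helper are made up for illustration, and the bias is left out as in the notebook's einsum.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

num_heads, head_dim = 2, 2

# Toy per-head attention outputs for a single token: one row of length head_dim per head
head_out = [[1.0, 2.0], [3.0, 4.0]]

# W plays the role of out_proj.weight.t(): [embed, embed], rows indexed by input feature
W = [[0.5, -1.0, 0.0, 2.0],
     [1.0, 0.0, 1.0, -1.0],
     [0.0, 2.0, -0.5, 1.0],
     [1.5, 1.0, 0.0, 0.0]]

# Usual path: concatenate the heads, then apply one projection
concat = [[v for head in head_out for v in head]]      # [1, embed]
projected = matmul(concat, W)[0]

# Per-head path: head h only touches rows h*head_dim .. (h+1)*head_dim of W
contribs = []
for h in range(num_heads):
    W_h = W[h * head_dim:(h + 1) * head_dim]           # [head_dim, embed]
    contribs.append(matmul([head_out[h]], W_h)[0])

summed = [sum(vals) for vals in zip(*contribs)]
print(projected, summed)
```

Both paths give the same vector, so `head_contrib` can be read as the share of the attention output that each head writes into the residual stream.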

In [7]:
import math
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float=0.1, max_len: int=32): # max_len defaults to 32 frames
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        # Compute div_term for the given d_model
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        # Build the table with shape [max_len, d_model]
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Reshape to [1, max_len, d_model] for batch-first inputs
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
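        """
        x: [batch_size, frames, tokens_per_frame, d_model]
        """
        b, f, t, d = x.shape
        # Reshape the PE to [1, f, 1, d_model] so it varies only along the frame axis
        # and is broadcast over the tokens of each frame.
        # Spatial position within a frame was already handled by the vision encoder.
        curr_pe = self.pe[:, :f, :].unsqueeze(2)
        # Add only as many PE rows as there are input frames
        x = x + curr_pe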
        return self.dropout(x)
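
To make the frame-only positional encoding concrete, here is a small pure-Python sketch of the same sinusoidal table. It uses a tiny `d_model` and toy frame and token counts; `sinusoidal_pe` is an illustrative name, not a function from the notebook.

```python
import math

def sinusoidal_pe(max_len: int, d_model: int):
    """Return a max_len x d_model table; even columns use sin, odd columns use cos."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            angle = pos * math.exp(i * (-math.log(10000.0) / d_model))
            table[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(angle)
    return table

pe = sinusoidal_pe(max_len=32, d_model=8)
# Frame 0 gets sin(0) = 0 in even columns and cos(0) = 1 in odd columns
print(pe[0])

# Every token of frame f receives the same row pe[f]; only the frame index matters
frame_of_token = [f for f in range(3) for _ in range(4)]   # 3 frames x 4 tokens each
offsets = [pe[f] for f in frame_of_token]
print(offsets[0] == offsets[3], offsets[3] == offsets[4])
```

Tokens 0 and 3 belong to the same frame and share an offset, while token 4 starts the next frame and gets a different one.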

In [8]:
import torch
from torchvision.transforms import v2

class GPUSigLIPProcessor:
    def __init__(self, processor):
        config = processor.image_processor

        # 1. Resize settings: bilinear + antialias=True is the key part
        # This matches the logic the fast processor uses when it handles tensors.
        self.resize = v2.Resize(
            size=(config.size['height'], config.size['width']),
            interpolation=v2.InterpolationMode.BILINEAR, # resample=2
            antialias=True # the most important setting for reducing the mismatch
        )

        # 2. Normalization settings
        # Computes (x - 0.5) / 0.5
        self.mean = torch.tensor(config.image_mean).view(1, 3, 1, 1)
        self.std = torch.tensor(config.image_std).view(1, 3, 1, 1)
        self.rescale_factor = config.rescale_factor

    def __call__(self, video_tensor):
        """
        video_tensor: (B, 3, T, H, W), uint8, GPU
        """
        b, c, t, h, w = video_tensor.shape
        device = video_tensor.device

        # Rearrange to (B*T, C, H, W)
        x = video_tensor.permute(0, 2, 1, 3, 4)
        x = x.flatten(0, 1)

        # [Step 1] Resize (can run on uint8 or on float32)
        # torchvision v2 accepts uint8 input and computes at higher precision internally.
        x = self.resize(x)

        # [Step 2] Convert to float32 and rescale (0~255 -> 0~1)
        x = x.to(torch.float32) * self.rescale_factor

        # [Step 3] Normalize: (x - 0.5) / 0.5
        # Keep mean and std on the input device so later calls skip the transfer
        self.mean = self.mean.to(device)
        self.std = self.std.to(device)
        x = (x - self.mean) / self.std

        # [Step 4] Return in bfloat16, the dtype the vision encoder is loaded with
        return x.to(torch.bfloat16)
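
The rescale and normalize steps together map uint8 pixels onto roughly [-1, 1]. A minimal scalar sketch, assuming the values named in the comments above (a 0~255 to 0~1 rescale, then mean 0.5 and std 0.5); the real numbers come from the processor config, and `normalize_pixel` is an illustrative name.

```python
def normalize_pixel(value: int, rescale_factor: float = 1 / 255,
                    mean: float = 0.5, std: float = 0.5) -> float:
    """Map one uint8 channel value to the range the encoder expects."""
    if not 0 <= value <= 255:
        raise ValueError("expected a uint8 value")
    return (value * rescale_factor - mean) / std

for v in (0, 128, 255):
    print(v, round(normalize_pixel(v), 4))
```

Black maps to -1, white to about +1, and mid-gray lands close to 0.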

In [33]:
class UNITE(nn.Module):
    def __init__(self, num_channel=3, num_cls=2, num_heads=12, max_len=32, dropout=0.1):
        super().__init__()

        model_id = "google/siglip2-base-patch16-384"
        self.vis_encoder = AutoModel.from_pretrained(
            model_id,
            device_map="auto",
            dtype=torch.bfloat16
        )
        self.embed_size = self.vis_encoder.config.vision_config.hidden_size
        processor = AutoProcessor.from_pretrained(model_id, use_fast=True)
        self.processor = GPUSigLIPProcessor(processor)
        self.num_heads = num_heads

        for para in self.vis_encoder.parameters():
            para.requires_grad = False
        self.vis_encoder.eval()

        self.class_token = nn.Parameter(torch.randn((self.embed_size,)), requires_grad=True)

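        # Forward dropout and max_len so the constructor arguments actually take effect
        self.pos_embedding = PositionalEncoding(self.embed_size, dropout=dropout, max_len=max_len)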
        self.first_encoder = ViTEncoder(self.embed_size, num_heads, dropout)
        self.encoders = nn.ModuleList([ViTEncoder(self.embed_size, num_heads, dropout) for _ in range(3)])
        self.mlp_head = nn.Linear(self.embed_size, num_cls)


    def forward(self, x, return_ad_param=False):
        self.vis_encoder.eval()

        # Input: [batch, c, frame, h, w]
        b, _, f, *_ = x.shape

        with torch.no_grad():
            # -> Preprocessing [batch * frame, c, h, w]
            x = self.processor(x)
            # -> Visual encoding [batch * frame, token/frame(576), dim/token (embed_size)]
            x = self.vis_encoder.vision_model(pixel_values=x).last_hidden_state
        # -> [batch, frame, token/frame, dim/token]
        x = x.reshape(b, f, -1, self.embed_size)
        x = self.pos_embedding(x)
        train_in = x # xi

        _, _, t, d = x.shape
        # Reshape for transformer
        # -> [batch, total token, dim/token]
        x = x.reshape(b, f*t, d)
        # Add class token
        cls_token = self.class_token.view(1, 1, -1).expand(b, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        P = None
        if return_ad_param:
            x, head_contrib = self.first_encoder(x, return_head_contrib=True)
            # head_contrib: [batch, head, total token + class token, dim/token]
            # -> [batch, head, total token, dim/token]
            head_contrib = head_contrib[:, :, 1:, :]
            # -> [batch, head, frame, token/frame, dim/token]
            head_contrib = head_contrib.view(b, self.num_heads, f, t, d)
            # -> mean pooling [batch, head, token/frame, dim/token]
            head_contrib = head_contrib.mean(dim=2) # A

            P = torch.einsum("bftd, bhtd -> bhf", train_in, head_contrib)
        else:
            x = self.first_encoder(x)

        for encoder in self.encoders:
            x = encoder(x)

        # Get only cls_token
        x = x[:, 0, :]
        x = x.view(b, -1)
        x = self.mlp_head(x)
        if return_ad_param:
            return x, P
        return x

In [34]:
class ADLoss(nn.Module):
    def __init__(self, num_cls=2, num_heads=12, max_len=32, delta_within=(0.01, -2.0), delta_between=0.5, eta=0.05):
        super().__init__()
        # C shape: [num_classes, num_heads, max_len]
        # Following Eq. (3) of the paper, one center is kept per class
        C = torch.zeros(num_cls, num_heads, max_len)
        self.register_buffer('C', C)

        self.num_cls = num_cls
        self.delta_within = torch.tensor(delta_within) # [0.01, -2.0] (True, Fake)
        self.delta_between = delta_between # 0.5
        self.eta = eta

    def forward(self, P, labels, log_detail=False):
        """
        P: [batch, num_heads, max_len] (Pooled Features)
        labels: [batch] (Class indices)
        """
        device = P.device
        self.delta_within = self.delta_within.to(device)

        P_norm = F.normalize(P.view(P.size(0), -1), p=2, dim=1).view_as(P)
        C_norm = F.normalize(self.C.view(self.num_cls, -1), p=2, dim=1).view_as(self.C)

        # --- 1. Center update (Eq. 3) ---
        # Update each class center with this batch's per-class mean.
        # Centers only move in training mode, so validation batches do not shift them.
        for c in range(self.num_cls):
            mask = (labels == c)
            if self.training and mask.any():
                # Mean of this class over the current batch
                batch_class_mean = P_norm[mask].mean(dim=0) # [num_heads, max_len]
                # Exponential moving average update
                with torch.no_grad():
                    self.C[c] = (1 - self.eta) * self.C[c] + self.eta * batch_class_mean.detach()

        # --- 2. Within-class loss (Eq. 4) ---
        # Distance between each sample and its own class center
        # P_norm: [B, H, F], C_norm[labels]: [B, H, F]
        diff_within = P_norm - C_norm[labels]
        # L2 norm over the head and frame dimensions
        dist_within = torch.norm(diff_within, p=2, dim=(1, 2))

        # Apply the delta of each sample's class
        loss_within = torch.relu(dist_within - self.delta_within[labels]).mean()

        # --- 3. Between-class loss (Eq. 5: over every pair of distinct classes) ---
        # Compute the pairwise distances between all class centers
        # Flattening the centers to [num_classes, H*F] keeps this simple
        C_flat = C_norm.view(self.num_cls, -1)

        # Differences for every class pair: [num_classes, num_classes, H*F]
        # Uses broadcasting: (N, 1, D) - (1, N, D) -> (N, N, D)
        diff_between = C_flat.unsqueeze(1) - C_flat.unsqueeze(0)

        # Distance matrix over all pairs: [num_classes, num_classes]
        dist_matrix = torch.norm(diff_between, p=2, dim=2)

        # Mask that selects k != l (distinct pairs)
        # torch.triu avoids counting (k, l) and (l, k) twice and drops the diagonal (k = l)
        mask_between = torch.triu(torch.ones(self.num_cls, self.num_cls, device=device), diagonal=1).bool()

        # Keep only the distances between different classes
        different_pairs_dist = dist_matrix[mask_between]

        # Eq. (5): max(delta - dist, 0)
        # Note: C is a buffer, so as written this term has no gradient path to the model
        loss_between = torch.relu(self.delta_between - different_pairs_dist).sum()
        if log_detail:
            return loss_within + loss_between, loss_within, loss_between

        return loss_within + loss_between
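
The two hinge terms are easier to follow on a toy case. The pure-Python sketch below mirrors the formulas in `ADLoss` (a within-class hinge with a per-class margin, and a between-class hinge over center pairs) using invented two-dimensional features. It leaves out the normalization and the moving-average center update, and the function names are illustrative.

```python
import math

def l2_distance(a, b):
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def within_loss(features, labels, centers, delta_within):
    """Mean hinge penalty for samples farther than delta from their own class center."""
    penalties = [max(l2_distance(f, centers[y]) - delta_within[y], 0.0)
                 for f, y in zip(features, labels)]
    return sum(penalties) / len(penalties)

def between_loss(centers, delta_between):
    """Summed hinge penalty for each pair of centers closer than delta_between."""
    total = 0.0
    for k in range(len(centers)):
        for l in range(k + 1, len(centers)):   # upper triangle: each pair once
            total += max(delta_between - l2_distance(centers[k], centers[l]), 0.0)
    return total

centers = [[1.0, 0.0], [0.0, 1.0]]            # toy unit-norm centers: real, fake
features = [[1.0, 0.0], [0.6, 0.8]]           # toy unit-norm pooled features
labels = [0, 1]

lw = within_loss(features, labels, centers, delta_within=(0.01, -2.0))
lb = between_loss(centers, delta_between=0.5)
print(round(lw, 4), round(lb, 4))
```

With the notebook's default margins, the negative margin for the fake class means a fake sample always pays at least 2.0, even when it sits exactly on its center. The real sample in this example pays nothing.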

In [35]:
from torchmetrics.classification import Accuracy, AveragePrecision, Precision, Recall

In [36]:
class LitUNITEClassifier(L.LightningModule):
    def __init__(
            self, num_cls=2, num_heads=12, max_len=32, dropout=0.1,
            delta_within=(0.01, -2.0), delta_between=0.5, eta=0.05,
            lambda_1=0.5, lambda_2=0.5, lr=1e-4, decay_steps=1000,
        ):
        super().__init__()
        self.save_hyperparameters()
        self.model = UNITE(
            num_cls=num_cls,
            num_heads=num_heads,
            max_len=max_len,
            dropout=dropout,
        )
        self.ce_loss = nn.CrossEntropyLoss()
        self.ad_loss = ADLoss(
            num_cls=num_cls,
            num_heads=num_heads,
            max_len=max_len,
            delta_within=delta_within,
            delta_between=delta_between,
            eta=eta,
        )
        self.lambda_1 = lambda_1
        self.lambda_2 = lambda_2

        self.lr = lr
        self.decay_steps = decay_steps # Set this in respect to batch size; original batch size was 32

        self.acc = Accuracy(task='multiclass', num_classes=num_cls)
        self.ap = AveragePrecision(task='multiclass', num_classes=num_cls)
        self.precision = Precision(task='multiclass', num_classes=num_cls)
        self.recall = Recall(task='multiclass', num_classes=num_cls)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logit, P = self.model(x, return_ad_param=True)
        loss_ad, within, between = self.ad_loss(P, y, log_detail=True)
        loss_ce = self.ce_loss(logit, y)
        loss = loss_ce * self.lambda_1 + loss_ad * self.lambda_2
        self.log("train/loss_ad", loss_ad, logger=True)
        self.log("train/loss_ad/loss_within", within, logger=True)
        self.log("train/loss_ad/loss_between", between, logger=True)
        self.log("train/loss_ce", loss_ce, logger=True)
        self.log("train/loss", loss, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logit, P = self.model(x, return_ad_param=True)
        loss_ad = self.ad_loss(P, y)
        loss_ce = self.ce_loss(logit, y)
        loss = loss_ce * self.lambda_1 + loss_ad * self.lambda_2
        self.log("val/loss_ad", loss_ad, logger=True)
        self.log("val/loss_ce", loss_ce, logger=True)
        self.log("val/loss", loss, prog_bar=True, logger=True)

        self.acc(logit, y)
        self.ap(logit, y)
        self.precision(logit, y)
        self.recall(logit, y)
        self.log("val/acc", self.acc, logger=True)
        self.log("val/ap", self.ap, logger=True)
        self.log("val/precision", self.precision, logger=True)
        self.log("val/recall", self.recall, logger=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logit = self.model(x)

        self.acc(logit, y)
        self.ap(logit, y)
        self.precision(logit, y)
        self.recall(logit, y)
        self.log("test/acc", self.acc, logger=True)
        self.log("test/ap", self.ap, logger=True)
        self.log("test/precision", self.precision, logger=True)
        self.log("test/recall", self.recall, logger=True)

    def configure_optimizers(self):
        optim = torch.optim.AdamW(self.model.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.StepLR(optim, self.decay_steps, gamma=0.5)

        return {
            "optimizer": optim,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

In [37]:
import kagglehub
from pathlib import Path
import cv2
import pandas as pd
import torch
import numpy as np
import math

from torch.utils.data import DataLoader, Dataset, random_split

class CelebDFDataset(Dataset):
    def __init__(self, is_test: bool, path_str:str, length=32, size=(384, 384), transform=None):
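        """
        Args:
            is_test (bool): load the test split if True, the train split if False
            path_str (str): root directory of the dataset
            length (int): sequence length in frames (default 32)
            size (tuple): (width, height) passed to cv2.resize for every frame
            transform: image preprocessing (optional)
        """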

        self.path = Path(path_str)
        self.is_test = is_test
        self.length = length
        self.size = size
        self.transform = transform

        # Find every mp4 file
        self.files = list(self.path.glob("*/*.mp4"))

        # Load the list of test videos
        # Line format: [1|0] [path] (e.g. 1 YouTube-real/00170.mp4)
        txt_path = self.path / "List_of_testing_videos.txt"
        test_df = pd.read_csv(txt_path, sep=" ", header=None, names=["label", "path"])

        # Convert the test paths to a set for fast membership checks
        # Replace backslashes with '/' so Windows and Linux paths compare equal
        test_files_set = set(test_df["path"].apply(lambda x: x.replace("\\", "/")).values)

        self.samples = [] # list of dicts with video_path, chunk_idx, label

        print(f"Processing metadata for {'Test' if is_test else 'Train'} set...")

        for next_file in self.files:
            # Path relative to the dataset root (e.g. YouTube-real/00170.mp4)
            rel_path = str(next_file.relative_to(self.path)).replace("\\", "/")

            # Check whether this file is in the test list
            is_in_test_list = rel_path in test_files_set

            # Skip files that do not belong to the requested split (train/test)
            if self.is_test != is_in_test_list:
                continue

            # Assign the label (0: real, 1: fake)
            # Decided from the folder name
            if "YouTube-real" in rel_path or "Celeb-real" in rel_path:
                label = 0
            elif "Celeb-synthesis" in rel_path:
                label = 1
            else:
                continue # skip unknown folders

            cap = cv2.VideoCapture(str(next_file))
            if cap.isOpened():
                frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
                cap.release()

                if frame_cnt <= 0: continue

                # Split one video into several samples (chunks)
                # Stride=2 (every other frame), Length=32
                # Each sample spans about 64 source frames
                # Number of usable frames after sampling with stride 2
                effective_frames = math.ceil(frame_cnt / 2)

                # Number of samples from this video (rounded up)
                num_chunks = math.ceil(effective_frames / self.length)

                for i in range(num_chunks):
                    self.samples.append({
                        "video_path": str(next_file),
                        "chunk_idx": i,
                        "label": label
                    })
            else:
                print(f"Cannot open video: {next_file}")

        print(f"Loaded {len(self.samples)} samples from {len(self.files)} files.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        meta = self.samples[idx]
        video_path = meta["video_path"]
        chunk_idx = meta["chunk_idx"]
        label = meta["label"]

        cap = cv2.VideoCapture(video_path)

        # Start frame = chunk index * sequence length * stride 2
        start_frame = chunk_idx * self.length * 2

        frames = []

        # Seek to the start position
        cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

        # Loop until up to `length` usable frames are collected.
        # Stride 2 could be done by keeping even offsets while reading or by skipping two frames at a time;
        # here frames are read sequentially and the offset is checked, which is the safer option.

        # Scan at most length * 2 source frames (64 to collect 32)
        for i in range(self.length * 2):
            ret, frame = cap.read()
            if not ret:
                break # reached the end of the video

            # Keep only even offsets (0, 2, 4, ...), i.e. every other frame
            if i % 2 == 0:
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Resize to the target size
                frame = cv2.resize(frame, self.size)
                frames.append(frame)

            if len(frames) == self.length:
                break

        cap.release()

        # Padding: if fewer than `length` frames were read, repeat the last frame
        if len(frames) < self.length:
            if len(frames) == 0:
                # Very rare case (e.g. the file opens but has no frames): fill with zeros
                # This should not normally happen
                w, h = self.size # same (width, height) that cv2.resize produces, so batches collate
                frames = [np.zeros((h, w, 3), dtype=np.uint8) for _ in range(self.length)]
            else:
                last_frame = frames[-1]
                while len(frames) < self.length:
                    frames.append(last_frame.copy())

        # Convert to a NumPy array: (T, H, W, C) -> (32, H, W, 3)
        frames_np = np.array(frames)

        if self.transform:
            # Placeholder: a transform would be applied here (usually per image).
            # A video transform library would take the whole clip as-is.
            # No transform is applied at the moment.
            pass

        # To tensor: (T, H, W, C) -> (C, T, H, W), the usual layout for PyTorch video models
        # Scaling 0~255 to 0~1 is done later, so it is omitted here
        frames_tensor = torch.from_numpy(frames_np).permute(3, 0, 1, 2)

        return frames_tensor, torch.tensor(label)
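
The chunking arithmetic used by the dataset can be checked in isolation. The sketch below repeats the same formulas for a hypothetical 300-frame video; `chunk_layout` and the frame count are made up for illustration.

```python
import math

def chunk_layout(frame_cnt: int, length: int = 32, stride: int = 2):
    """Return (num_chunks, start frame of each chunk, real frames in the last chunk)."""
    effective_frames = math.ceil(frame_cnt / stride)        # frames kept after striding
    num_chunks = math.ceil(effective_frames / length)
    starts = [i * length * stride for i in range(num_chunks)]
    last_real = effective_frames - (num_chunks - 1) * length
    return num_chunks, starts, last_real

num_chunks, starts, last_real = chunk_layout(frame_cnt=300)
print(num_chunks, starts, last_real)
```

For this video the last chunk holds 22 real frames, so `__getitem__` would pad it with 10 copies of the final frame.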

In [38]:
class CelebDFDataModule(L.LightningDataModule):
    def __init__(self, length=32, batch_size=32, num_workers=8):
        super().__init__()
        self.length = length
        self.batch_size = batch_size
        self.num_workers = num_workers

    def prepare_data(self):
        kagglehub.dataset_download("reubensuju/celeb-df-v2")

    def setup(self, stage=None):
        path = kagglehub.dataset_download("reubensuju/celeb-df-v2")
        if stage == "fit" or stage is None:
            train_full = CelebDFDataset(is_test=False, path_str=path, length=self.length)
            self.celebdf_train, self.celebdf_val = random_split(train_full, [0.9, 0.1])
        if stage == "test" or stage is None:
            self.celebdf_test = CelebDFDataset(is_test=True, path_str=path, length=self.length)

    def train_dataloader(self):
        # Reshuffle every epoch; random_split alone fixes a single order for all epochs
        return DataLoader(self.celebdf_train, shuffle=True, num_workers=self.num_workers, batch_size=self.batch_size, pin_memory=True)
    def val_dataloader(self):
        return DataLoader(self.celebdf_val, num_workers=self.num_workers, batch_size=self.batch_size, pin_memory=True)
    def test_dataloader(self):
        return DataLoader(self.celebdf_test, num_workers=self.num_workers, batch_size=self.batch_size)

In [15]:
import wandb
from google.colab import userdata
wandb_key = userdata.get('wandb_api')
wandb.login(key=wandb_key)

[34m[1mwandb[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mdhnam0502[0m ([33mdhnam0502-likelion[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [40]:
BATCH_SIZE = 20
DECAY_STEPS = (1000 * 32) // BATCH_SIZE

datamodule = CelebDFDataModule(batch_size=BATCH_SIZE)
lit_classifier = LitUNITEClassifier(decay_steps=DECAY_STEPS)
lit_classifier = torch.compile(lit_classifier)
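
`DECAY_STEPS` rescales the reference schedule (1000 steps at batch size 32) so that the learning rate halves after about the same number of samples. The sketch below uses pure Python rather than the real scheduler; both helper names are illustrative, and `step_lr` only mimics the StepLR rule with `gamma=0.5`.

```python
def scaled_decay_steps(base_steps: int, base_batch: int, batch_size: int) -> int:
    """Keep the number of samples seen between LR halvings roughly constant."""
    return (base_steps * base_batch) // batch_size

def step_lr(base_lr: float, step: int, decay_steps: int, gamma: float = 0.5) -> float:
    """Learning rate of a StepLR-style schedule after `step` optimizer steps."""
    return base_lr * gamma ** (step // decay_steps)

decay_steps = scaled_decay_steps(1000, 32, 20)
print(decay_steps)
for step in (0, 1599, 1600, 3200):
    print(step, step_lr(1e-4, step, decay_steps))
```

With a batch size of 20, the first halving happens at step 1600 instead of step 1000.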

In [41]:
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks.lr_monitor import LearningRateMonitor
from lightning.pytorch.loggers import WandbLogger  # same package as the Trainer; avoid mixing with pytorch_lightning
from lightning.pytorch.callbacks import TQDMProgressBar


wandb_logger = WandbLogger(project="UNITE_deepfake_classification", name="baseline", log_model=True)

ckpt = ModelCheckpoint(monitor="val/acc", mode="max", save_last=True)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

trainer =  L.Trainer(
    max_epochs=25,
    # profiler='simple',
    logger=wandb_logger,
    callbacks=[ckpt, lr_monitor, TQDMProgressBar()],
    precision='bf16-mixed'
)



INFO:pytorch_lightning.utilities.rank_zero:Using bfloat16 Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores


In [32]:
wandb.finish()

Run history:
epoch,▁▁▁▁
lr-AdamW,▁
train/loss,▁█▄▄
train/loss_ad,█▁▆▄
train/loss_ad/loss_between,▁▇██
train/loss_ad/loss_within,█▁▆▄
train/loss_ce,▁█▃▅
trainer/global_step,▁▃▄▆█

Run summary:
epoch,0.0
lr-AdamW,0.0001
train/loss,1.31554
train/loss_ad,2.3158
train/loss_ad/loss_between,0.49422
train/loss_ad/loss_within,1.82158
train/loss_ce,0.31529
trainer/global_step,199.0
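
The logged summary can be checked against the loss definition in `training_step`. This sketch copies the values from the summary above and assumes the default weights `lambda_1 = lambda_2 = 0.5` were used for this run.

```python
summary = {
    "train/loss": 1.31554,
    "train/loss_ad": 2.3158,
    "train/loss_ad/loss_between": 0.49422,
    "train/loss_ad/loss_within": 1.82158,
    "train/loss_ce": 0.31529,
}
lambda_1 = lambda_2 = 0.5   # assumed: the defaults of LitUNITEClassifier

# AD loss is the sum of its two hinge terms
ad = summary["train/loss_ad/loss_within"] + summary["train/loss_ad/loss_between"]
# Total loss is the weighted sum of cross-entropy and AD loss
total = lambda_1 * summary["train/loss_ce"] + lambda_2 * summary["train/loss_ad"]
# For two classes the between term is relu(delta_between - distance), so the
# center distance implied by the log is delta_between minus the logged value
implied_center_dist = 0.5 - summary["train/loss_ad/loss_between"]
print(round(ad, 5), round(total, 5), round(implied_center_dist, 5))
```

The totals agree with the logged values to rounding. The between-class term sits just below `delta_between`, which suggests the two normalized centers were still nearly identical at step 199.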


In [None]:
trainer.fit(lit_classifier, datamodule=datamodule)

Using Colab cache for faster access to the 'celeb-df-v2' dataset.


Using Colab cache for faster access to the 'celeb-df-v2' dataset.
Processing metadata for Train set...


INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loaded 38569 samples from 6529 files.


/usr/local/lib/python3.12/dist-packages/lightning/pytorch/utilities/model_summary/model_summary.py:242: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.


Sanity Checking: |          | 0/? [00:00<?, ?it/s]