In [1]:
from dataset import create_dataloaders

train_loader = create_dataloaders("../data/Training/01.원천데이터", batch_size=6, test_ratio = 0, image_size=160, workers=16)
val_loader = create_dataloaders( "../data/Validation/01.원천데이터", batch_size=6, test_ratio = 0, image_size=160, workers=16)

In [2]:
import torch

DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {DEVICE} device")

Using cuda device


In [3]:
from multi_modal import MultiModalNet

multi_modal_checkpoint = torch.load("multi_modal_24_12_15.pth", map_location=DEVICE)
teacher_model= MultiModalNet(num_classes=1).to(DEVICE)
teacher_model.load_state_dict(multi_modal_checkpoint, strict=False)

Using cuda device


_IncompatibleKeys(missing_keys=['video_net.layer1.spatial_conv.weight', 'video_net.layer1.bn1.weight', 'video_net.layer1.bn1.bias', 'video_net.layer1.bn1.running_mean', 'video_net.layer1.bn1.running_var', 'video_net.layer1.temporal_conv.weight', 'video_net.layer1.bn2.weight', 'video_net.layer1.bn2.bias', 'video_net.layer1.bn2.running_mean', 'video_net.layer1.bn2.running_var', 'video_net.layer2.spatial_conv.weight', 'video_net.layer2.bn1.weight', 'video_net.layer2.bn1.bias', 'video_net.layer2.bn1.running_mean', 'video_net.layer2.bn1.running_var', 'video_net.layer2.temporal_conv.weight', 'video_net.layer2.bn2.weight', 'video_net.layer2.bn2.bias', 'video_net.layer2.bn2.running_mean', 'video_net.layer2.bn2.running_var', 'video_net.layer3.spatial_conv.weight', 'video_net.layer3.bn1.weight', 'video_net.layer3.bn1.bias', 'video_net.layer3.bn1.running_mean', 'video_net.layer3.bn1.running_var', 'video_net.layer3.temporal_conv.weight', 'video_net.layer3.bn2.weight', 'video_net.layer3.bn2.bias', 

In [4]:
import torch.nn as nn
from vision_model import R2Plus1D_Block

class mono_modal(nn.Module):
    def __init__(self, num_classes):
        super(mono_modal, self).__init__()
        self.layer1 = R2Plus1D_Block(3, 128, stride = 2)
        self.layer2 = R2Plus1D_Block(128, 64, stride = 4)

        self.pool = nn.AdaptiveAvgPool3d((1, 1, 1))

        self.fc_sequence = nn.Sequential(
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, num_classes)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)

        x = self.pool(x)
        x = x.view(x.size(0), -1)

        x = self.fc_sequence(x)
        return x 

In [5]:
student_model = mono_modal(num_classes=1).to(DEVICE)

In [6]:
import torch.optim as optim

learning_rate=0.0001 
optimizer = optim.Adam(student_model.parameters(), lr=learning_rate)

In [7]:
import torch.nn.functional as F
from tqdm import tqdm
import torch.nn as nn
import torch
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score

def distill_r2plus1d_blocks(teacher_model, student_model, video_input):
    with torch.no_grad():
        teacher_features1 = teacher_model.video_net.layer2(teacher_model.video_net.layer1(video_input))
        teacher_features2 = teacher_model.video_net.layer4(teacher_model.video_net.layer3(teacher_features1))

    student_features1 = student_model.layer1(video_input)
    student_features2 = student_model.layer2(student_features1)

    loss_r2p1d1 = F.mse_loss(student_features1, teacher_features1)
    loss_r2p1d2 = F.mse_loss(student_features2, teacher_features2)

    total_loss_r2p1d = (loss_r2p1d1 + loss_r2p1d2) / 2

    return total_loss_r2p1d, teacher_features2, student_features2

def distill_fc_sequence(teacher_model, student_model, video_features, sensor_features, temperature=3.0):
    with torch.no_grad():
        combined_teacher_features = torch.cat([video_features, sensor_features], dim=1) 
        teacher_output = teacher_model.fc_sequence(combined_teacher_features)

    student_output = student_model.fc_sequence(video_features)

    teacher_prob = torch.sigmoid(teacher_output / temperature)  
    student_prob = torch.sigmoid(student_output / temperature) 

    teacher_dist = torch.cat([1 - teacher_prob, teacher_prob], dim=1)   
    student_dist = torch.cat([1 - student_prob, student_prob], dim=1) 

    epsilon = 1e-7
    student_log_dist = (student_dist + epsilon).log()
    loss_fc = (temperature**2) * F.kl_div(
        student_log_dist,  # log_prob of student
        teacher_dist,       # prob of teacher
        reduction="batchmean"
    )
    return loss_fc, student_output

def train(
    teacher_model, 
    student_model, 
    train_loader, 
    val_loader, 
    optimizer, 
    num_epochs=10, 
    alpha=0.7, 
    beta=0.3, 
    gamma=1.0,   # 추가된 gamma 가중치
    temperature=3.0, 
    device="cuda"
):
    teacher_model.eval()
    student_model.to(device)

    # BCEWithLogitsLoss 사용 예시 (이진 분류 가정)
    criterion = nn.BCEWithLogitsLoss()

    for epoch in range(num_epochs):
        student_model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for video_input, sensor_input, labels in pbar:
            video_input, sensor_input, labels = video_input.to(device), sensor_input.to(device), labels.unsqueeze(1).float().to(DEVICE)
            optimizer.zero_grad()

            # 교사 모델 특징 추출
            video_features = teacher_model.video_net(video_input)
            sensor_features = teacher_model.sensor_net(sensor_input) 

            # R2Plus1D 블록 지식 증류 손실
            loss_r2p1d, teacher_features, student_features = distill_r2plus1d_blocks(
                teacher_model, student_model, video_input
            )

            student_features = student_model.pool(student_features)
            student_features = student_features.view(student_features.size(0), -1)

            # fc_sequence 지식 증류 손실
            loss_fc, student_output = distill_fc_sequence(
                teacher_model, student_model, video_features, sensor_features, temperature
            )

            # 분류 손실
            classification_loss = criterion(student_output, labels)

            # 총 손실 계산: alpha, beta, gamma 가중치 적용
            loss = alpha * loss_r2p1d + beta * loss_fc + gamma * classification_loss
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            pbar.set_postfix(loss=loss.item())

            # 정확도 계산
            # 이진 분류라면 다음과 같이 sigmoid 후 thresholding 필요
            # 예: predicted = (torch.sigmoid(student_output) > 0.5).long()
            # 여기서는 다중 클래스 가정으로 torch.max 사용
            predicted = (torch.sigmoid(student_output) > 0.5).long()
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

        train_accuracy = correct / total * 100
        print(f"Training Loss: {total_loss:.4f}, Training Accuracy: {train_accuracy:.2f}%")

        # Validation phase
        student_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        all_labels = []
        all_preds = []

        with torch.no_grad():
            for video_input, _, labels in val_loader:
                video_input, labels = video_input.to(device), labels.unsqueeze(1).float().to(DEVICE)
                
                student_output = student_model(video_input)
                
                batch_val_loss = criterion(student_output, labels).item()
                val_loss += batch_val_loss

                predicted = (torch.sigmoid(student_output) > 0.5).long()
                correct += (predicted == labels).sum().item()
                total += labels.size(0)

                all_labels.extend(labels.cpu().numpy().flatten())
                all_preds.extend(predicted.cpu().numpy().flatten())

        val_accuracy = correct / total * 100
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")
        cm = confusion_matrix(all_labels, all_preds)

        # Precision, Recall, F1-Score 계산
        precision = precision_score(all_labels, all_preds)
        recall = recall_score(all_labels, all_preds)
        f1 = f1_score(all_labels, all_preds)

        # 결과 출력
        print("Confusion Matrix:\n", cm)
        print(f"Precision: {precision:.2f}, Recall: {recall:.2f}, F1-Score: {f1:.2f}")
        torch.cuda.empty_cache()


In [8]:
train(
    teacher_model=teacher_model,
    student_model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=10,
    alpha=0.005,
    beta=0.005,
    gamma = 0.99,
    temperature=5.0,
    device=DEVICE
)

Epoch 1/10: 100%|██████████| 3022/3022 [26:23<00:00,  1.91it/s, loss=1.49] 

Training Loss: 2907.7771, Training Accuracy: 97.99%





Validation Loss: 213.6727, Validation Accuracy: 80.15%
Confusion Matrix:
 [[ 385  183]
 [ 268 1436]]
Precision: 0.89, Recall: 0.84, F1-Score: 0.86


Epoch 2/10: 100%|██████████| 3022/3022 [26:20<00:00,  1.91it/s, loss=1.22] 

Training Loss: 2700.9620, Training Accuracy: 99.28%





Validation Loss: 194.9082, Validation Accuracy: 75.57%
Confusion Matrix:
 [[  14  554]
 [   1 1703]]
Precision: 0.75, Recall: 1.00, F1-Score: 0.86


Epoch 3/10:   2%|▏         | 60/3022 [00:40<33:15,  1.48it/s, loss=0.64]  


KeyboardInterrupt: 

In [None]:
model_save_path = 'model.pth'
torch.save({
    'model_state_dict': student_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, model_save_path)
print(f"Model and parameters saved to {model_save_path}")

NameError: name 'torch' is not defined