## 画像のエンコード

In [None]:
import os
import torch
import gc
from PIL import Image
from tqdm import tqdm
from transformers import CLIPProcessor, CLIPModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
model.eval()

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def load_image_paths(file_path):
    """Read a text file and return its lines as a list of image paths.

    Args:
        file_path: Path of a text file holding one image path per line.

    Returns:
        list[str]: The lines of the file with line terminators removed.
    """
    with open(file_path, 'r') as handle:
        return handle.read().splitlines()

def correct_paths(image_paths):
    """Prefix bare image file names with their category directory.

    A path without any directory separator is treated as a bare file name of
    the form ``<category>_<suffix>``: everything before the last underscore
    becomes the parent directory. Paths that already contain a separator
    (``/`` or ``\\``) and bare names without an underscore are returned
    unchanged.

    Args:
        image_paths: Iterable of path strings.

    Returns:
        list[str]: One path per input entry, in the same order.
    """
    corrected_paths = []
    for path in image_paths:
        # Both separators count as "already has a directory"; the consumer of
        # these paths normalises "\\" and "/" alike, so a Windows-style path
        # must not receive a second, bogus parent directory.
        has_directory = '/' in path or '\\' in path
        if not has_directory:
            parent_dir, underscore, _ = path.rpartition('_')
            if underscore:
                path = os.path.join(parent_dir, path)
        corrected_paths.append(path)
    return corrected_paths

def process_images_and_texts(image_paths, image_root=os.path.join('images', 'Images')):
    """Load a batch of images and derive one text label per image.

    Each path is first completed by ``correct_paths``. The label is the first
    path component (the category directory); the image is opened relative to
    ``image_root`` and converted to RGB.

    Args:
        image_paths: Iterable of image paths using ``/`` or ``\\`` separators.
        image_root: Directory that contains the category folders. The default
            resolves to ``images\\Images`` on Windows, as before, and to
            ``images/Images`` elsewhere.

    Returns:
        tuple[list, list[str]]: The RGB images and their labels, index-aligned.
    """
    images, texts = [], []
    for path in correct_paths(image_paths):
        # Normalise separators once so the same list works on every platform.
        parts = path.replace("\\", "/").split('/')
        texts.append(parts[0])
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(os.path.join(image_root, *parts)) as img:
            images.append(img.convert("RGB"))
    return images, texts

def encode_images(images, texts=None):
    """Encode a batch of images with the module-level CLIP model.

    Args:
        images: Images accepted by the module-level ``processor``.
        texts: Optional texts. When truthy, images and texts go through the
            full model and its ``image_embeds`` output is returned; otherwise
            ``model.get_image_features`` is used on the images alone.

    Returns:
        The image embeddings, moved to the CPU.
    """
    with_text = bool(texts)

    if with_text:
        batch = processor(images=images, text=texts, return_tensors="pt", padding=True)
    else:
        batch = processor(images=images, return_tensors="pt", padding=True)
    batch = batch.to(device)

    with torch.no_grad():
        if with_text:
            embeddings = model(**batch).image_embeds
        else:
            embeddings = model.get_image_features(**batch)
    return embeddings.cpu()

def save_encoded_features(image_paths, batch_size, save_path, include_texts=False):
    """Encode all images batch by batch and save the stacked features.

    Args:
        image_paths: List of image paths to encode.
        batch_size: Number of images encoded per forward pass.
        save_path: Destination passed to ``torch.save``.
        include_texts: When true, the per-image labels are passed to
            ``encode_images`` together with the images.
    """
    # Ceiling division: a trailing partial batch counts as one more batch.
    num_batches, leftover = divmod(len(image_paths), batch_size)
    if leftover != 0:
        num_batches += 1

    chunks = []
    for batch_idx in tqdm(range(num_batches)):
        start = batch_idx * batch_size
        images, texts = process_images_and_texts(image_paths[start:start + batch_size])
        if not include_texts:
            texts = None
        chunks.append(encode_images(images, texts))
        # Release cached GPU memory and dropped Python objects between batches.
        torch.cuda.empty_cache()
        gc.collect()

    stacked = torch.cat(chunks)
    print(f"Size: {stacked.shape}")
    torch.save(stacked, save_path)

# Train set encoding
# The output is named `train_image_encoded.pt` so that it matches the
# `{split}_image_encoded.pt` pattern that ThingsMEGDataset loads.
train_image_paths = load_image_paths(os.path.join('data', 'train_image_paths.txt'))
save_encoded_features(
    train_image_paths,
    batch_size=512,
    save_path=os.path.join('data', 'train_image_encoded.pt'),
    include_texts=True,
)

# Validation set encoding
val_image_paths = load_image_paths(os.path.join('data', 'val_image_paths.txt'))
save_encoded_features(
    val_image_paths,
    batch_size=512,
    save_path=os.path.join('data', 'val_image_encoded.pt'),
    include_texts=True,
)

## 事前学習

In [None]:
## 事前学習

import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from omegaconf import DictConfig
from termcolor import cprint
from torch_geometric.nn import GATConv
from torchinfo import summary
from torchmetrics import Accuracy
from tqdm import tqdm

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SubjectLayers(nn.Module):
    """Linear channel-mixing layer with a separate weight matrix per subject.

    The code is modified from
    https://github.com/facebookresearch/brainmagick/blob/1aa77d2dbd801b4901aeadfd8606d26a89ee7e3e/bm/models/common.py#L45

    Args:
        in_channels: Size of the channel axis of the input.
        out_channels: Size of the channel axis of the output.
        n_subjects: Number of distinct subjects (one weight matrix each).
        init_id: Start every matrix from the identity (before scaling);
            requires ``in_channels == out_channels``.
    """
    def __init__(self, in_channels: int, out_channels: int, n_subjects: int, init_id: bool = False):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels))
        if init_id:
            assert in_channels == out_channels
            identity = torch.eye(in_channels)
            self.weights.data[:] = identity.unsqueeze(0)
        # Scale by 1/sqrt(in_channels) in both initialisation modes.
        self.weights.data *= 1 / in_channels**0.5

    def forward(self, x, subjects):
        """Apply each sample's subject-specific matrix along the channel axis.

        Args:
            x: Tensor indexed as (batch, channel, time).
            subjects: Integer subject index for every sample in the batch.

        Returns:
            Tensor indexed as (batch, out_channels, time).
        """
        # Pick one (in_channels, out_channels) matrix per sample.
        per_sample_weights = self.weights[subjects.view(-1)]
        return torch.einsum("bct,bcd->bdt", x, per_sample_weights)

class ResidualAdd(nn.Module):
    """Wrap a module with a skip connection: ``fn(x) + x``.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Args:
        fn: Module whose output has a shape compatible with its input.
    """
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``self.fn(x, **kwargs) + x``.

        The sum is computed out of place. An in-place ``+=`` on the output of
        ``fn`` would overwrite the caller's tensor whenever ``fn`` returns its
        input unchanged, and would invalidate values that autograd saved for
        the backward pass when ``fn`` ends in an op that keeps its output.
        """
        return self.fn(x, **kwargs) + x


class ProjMeg(nn.Sequential):
    """Projection head applied to the MEG encoder output.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Layout: Linear -> residual(GELU, Linear, Dropout) -> LayerNorm.

    Args:
        emb_dim: Input feature size.
        out_dim: Output feature size.
        drop_proj: Dropout probability inside the residual branch.
    """
    def __init__(self, emb_dim=512, out_dim=512, drop_proj=0.5):
        input_proj = nn.Linear(emb_dim, out_dim)
        residual_branch = nn.Sequential(
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.Dropout(drop_proj),
        )
        output_norm = nn.LayerNorm(out_dim)
        super().__init__(input_proj, ResidualAdd(residual_branch), output_norm)


class ProjImg(nn.Sequential):
    """Projection head applied to the pre-computed image features.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Layout: Linear -> residual(GELU, Linear, Dropout) -> LayerNorm.

    Args:
        emb_dim: Input feature size.
        out_dim: Output feature size.
        drop_proj: Dropout probability inside the residual branch.
    """
    def __init__(self, emb_dim=512, out_dim=512, drop_proj=0.3):
        input_proj = nn.Linear(emb_dim, out_dim)
        residual_branch = nn.Sequential(
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.Dropout(drop_proj),
        )
        output_norm = nn.LayerNorm(out_dim)
        super().__init__(input_proj, ResidualAdd(residual_branch), output_norm)

class ChannelAttention(nn.Module):
    """Channel-wise attention block.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Query, key and output projections are Linear -> LayerNorm -> Dropout(0.3)
    stacks over the channel axis. All Linear layers are re-initialised with
    Xavier-normal weights and zero biases.

    Args:
        sequence_num: Nominal sequence length; only used to derive the
            attention scaling factor.
        inter: Pooling interval; ``sequence_num / inter`` (truncated) is the
            value whose square root scales the attention scores.
        num_channel: Size of the channel axis.
    """
    def __init__(self, sequence_num=250, inter=30, num_channel=271):
        super(ChannelAttention, self).__init__()
        self.sequence_num = sequence_num
        self.inter = inter
        self.extract_sequence = int(self.sequence_num / self.inter) 
        self.query = nn.Sequential(
            nn.Linear(num_channel, num_channel),
            nn.LayerNorm(num_channel),
            nn.Dropout(0.3)
        )
        self.key = nn.Sequential(
            nn.Linear(num_channel, num_channel),
            nn.LayerNorm(num_channel),
            nn.Dropout(0.3)
        )
        self.projection = nn.Sequential(
            nn.Linear(num_channel, num_channel),
            nn.LayerNorm(num_channel),
            nn.Dropout(0.3),
        )
        # Dropout with p=0 leaves the attention scores untouched.
        self.drop_out = nn.Dropout(0)
        # NOTE(review): this pooling layer is created but never called in forward().
        self.pooling = nn.AvgPool2d(kernel_size=(1, self.inter), stride=(1, self.inter))
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Apply channel attention to a 4-D input indexed as (b, o, c, s).

        Returns a tensor with the same (b, o, c, s) layout.
        """
        # Move channels last so the Linear layers act on the channel axis.
        temp = rearrange(x, 'b o c s -> b o s c')
        query = self.query(temp)
        key = self.key(temp)

        scaling = self.extract_sequence ** 0.5

        # Channel-by-channel scores, contracted over the sequence axis.
        attn_scores = torch.einsum('b o s c, b o s m -> b o c m', query, key) / scaling
        attn_scores = F.softmax(attn_scores, dim=-1)
        attn_scores = self.drop_out(attn_scores)
        # NOTE(review): `m` is absent from the output subscripts, so the scores
        # are summed over `m`. After the softmax over that axis (and p=0
        # dropout) each sum is 1, which makes this product equal to `x`.
        # Confirm whether 'b o m s, b o c m -> b o c s' was intended.
        out = torch.einsum('b o c s, b o c m -> b o c s', x, attn_scores)
        out = rearrange(out, 'b o c s -> b o s c')
        out = self.projection(out)
        out = rearrange(out, 'b o s c -> b o c s')
        return out



class EEG_GAT(nn.Module):
    """Graph-attention layer over sensor channels.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Every channel is a graph node whose feature vector is its time series. The
    graph is fully connected without explicit self-edges.

    Args:
        in_channels: Node feature size (number of time samples per channel).
        out_channels: Output feature size per node.
        num_channels: Number of sensor channels (graph nodes).
    """
    def __init__(self, in_channels=281, out_channels=281, num_channels=271):
        super(EEG_GAT, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv1 = GATConv(in_channels=in_channels, out_channels=out_channels, heads=1)

        self.num_channels = num_channels
        # All ordered pairs (i, j) with i != j, built directly as int64 in the
        # same order as a nested `for i ... for j ...` loop.
        node_ids = torch.arange(num_channels)
        sources = node_ids.repeat_interleave(num_channels)
        targets = node_ids.repeat(num_channels)
        off_diagonal = sources != targets
        edge_index = torch.stack([sources[off_diagonal], targets[off_diagonal]]).contiguous()
        # A non-persistent buffer follows the module across devices with
        # `.to(...)` but stays out of the state dict, so existing checkpoints
        # still load.
        self.register_buffer("edge_index", edge_index, persistent=False)

    def forward(self, x):
        """Run graph attention on a (batch, 1, channels, features) tensor.

        Returns:
            Tensor of shape (batch, 1, channels, out_channels).
        """
        batch_size, _, num_channels, num_features = x.size()

        # (B, 1, C, T) -> (B * C, T): one row per graph node.
        nodes = x.squeeze(1).reshape(batch_size * num_channels, num_features)

        # NOTE(review): `edge_index` only references node ids below
        # `self.num_channels`, i.e. the rows of the first sample in the batch;
        # rows of later samples have no edges listed. Kept as is because
        # changing it would alter trained behaviour and multiply the edge
        # count by the batch size - confirm whether this is intended.
        out = self.conv1(nodes, self.edge_index)

        # (B * C, out_channels) -> (B, 1, C, out_channels)
        out = out.view(batch_size, num_channels, -1)
        return out.unsqueeze(1)


class MEGEncoder(nn.Module):
    """MEG encoder: graph attention, two conv branches, attention fusion.

    The code is modified from https://github.com/ChiShengChen/MUSE_EEG

    Pipeline:
        1. Residual graph attention over channels (``sa``).
        2. Two parallel convolutional branches: temporal-then-spatial
           (``tsconv``) and spatial-then-temporal (``stconv``).
        3. Self-attention with residual + LayerNorm on each branch, then on
           the two branches concatenated along the sequence axis.
        4. Subject-specific linear mixing, flatten, final linear projection.

    Args:
        n_channel: Number of sensor channels.
        len_sec: Number of time samples per trial.
        emb_dim: Number of convolutional feature maps / attention width.
        out_dim: Size of the returned embedding.
        m1: Temporal convolution kernel length.
        m2: Temporal average-pooling window.
        s: Temporal average-pooling stride.
        n_subjects: Number of subjects for the subject-specific layer.
    """
    def __init__(self, n_channel=271, len_sec=281, emb_dim=40, out_dim=512, m1=25, m2=51, s=5, n_subjects=4):
        super(MEGEncoder, self).__init__()
        # The graph layer works on per-channel time series, so its feature
        # size has to follow `len_sec` instead of a fixed default.
        self.sa = ResidualAdd(nn.Sequential(
            EEG_GAT(in_channels=len_sec, out_channels=len_sec),
            nn.Dropout(0.3)
        ))
        self.attn_1 = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=4)
        self.attn_2 = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=4)
        self.attn_3 = nn.MultiheadAttention(embed_dim=emb_dim, num_heads=8, dropout=0.75)

        self.flatten = nn.Flatten()
        self.subject_layers = SubjectLayers(in_channels=emb_dim, out_channels=emb_dim, n_subjects=n_subjects)

        # Temporal convolution and pooling first, then collapse the channel axis.
        self.tsconv = nn.Sequential(
            nn.Conv2d(1, emb_dim, (1, m1), (1, 1)),
            nn.AvgPool2d((1, m2), (1, s)),
            nn.BatchNorm2d(emb_dim),
            nn.ELU(),
            nn.Conv2d(emb_dim, emb_dim, (n_channel, 1), (1, 1)),
            nn.BatchNorm2d(emb_dim),
            nn.ELU(),
            nn.Dropout(0.5),
        )

        # Collapse the channel axis first, then temporal convolution and pooling.
        self.stconv = nn.Sequential(
            nn.Conv2d(1, emb_dim, (n_channel, 1), (1, 1)),
            nn.BatchNorm2d(emb_dim),
            nn.ELU(),
            nn.Conv2d(emb_dim, emb_dim, (1, m1), (1, 1)),
            nn.AvgPool2d((1, m2), (1, s)),
            nn.BatchNorm2d(emb_dim),
            nn.ELU(),
            nn.Dropout(0.5),
        )

        # Each branch yields (len_sec - m1 - m2 + 1) // s + 1 time steps; the
        # two branches are concatenated along time, hence the factor 2.
        fc_in_dim = 2 * emb_dim * ((len_sec - m1 - m2 + 1) // s + 1)
        self.fc = nn.Linear(fc_in_dim, out_dim)
        # Not used in forward(); kept so the module's attributes are unchanged.
        self.do = nn.Dropout(0.5)

        self.norm_1 = nn.LayerNorm(emb_dim)
        self.norm_2 = nn.LayerNorm(emb_dim)
        self.norm_3 = nn.LayerNorm(emb_dim)

    def forward(self, x, subjects):
        """Encode a batch of MEG trials.

        Args:
            x: Tensor of shape (batch, 1, n_channel, len_sec).
            subjects: Integer subject index per sample.

        Returns:
            Tensor of shape (batch, out_dim).
        """
        x = self.sa(x)

        # (B, emb, 1, T') -> (T', B, emb): sequence-first layout for attention.
        ts = self.tsconv(x).squeeze(2).permute(2, 0, 1)
        st = self.stconv(x).squeeze(2).permute(2, 0, 1)

        attn_ts, _ = self.attn_1(ts, ts, ts)
        attn_st, _ = self.attn_2(st, st, st)

        bf_ts_features = self.norm_1(attn_ts + ts)
        bf_st_features = self.norm_2(attn_st + st)

        # Fuse both branches by attending over their concatenated sequences.
        combined_features = torch.cat((bf_ts_features, bf_st_features), dim=0)
        attn_combined, _ = self.attn_3(combined_features, combined_features, combined_features)
        # (2T', B, emb) -> (B, emb, 2T')
        x = self.norm_3(attn_combined + combined_features).permute(1, 2, 0)

        # Apply the subject layers
        x = self.subject_layers(x, subjects)

        x = self.flatten(x)
        x = self.fc(x)

        return x

class ContrastiveModel(nn.Module):
    """CLIP-style contrastive model pairing MEG trials with image features.

    The code is modified from https://github.com/eeyhsong/NICE-EEG

    Args:
        meg_encoder: Module called as ``meg_encoder(X, subjects)``.
        proj_meg: Projection applied to the MEG encoder output.
        proj_img: Projection applied to the image features.
    """
    def __init__(self, meg_encoder, proj_meg, proj_img):
        super(ContrastiveModel, self).__init__()
        self.meg_encoder = meg_encoder
        self.proj_meg = proj_meg
        self.proj_img = proj_img
        # Learnable log-temperature, initialised to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, X, image_features, subjects):
        """Return the scaled cosine-similarity matrix between MEG and images.

        Rows index the MEG samples of the batch, columns index the images.
        """
        meg_emb = self.proj_meg(self.meg_encoder(X, subjects))
        img_emb = self.proj_img(image_features)

        # L2-normalise both sides so the dot product is a cosine similarity.
        meg_emb = meg_emb / meg_emb.norm(dim=1, keepdim=True)
        img_emb = img_emb / img_emb.norm(dim=1, keepdim=True)

        similarity = torch.matmul(meg_emb, img_emb.t())
        return self.logit_scale.exp() * similarity
    
# meg_encoder = MEGEncoder(out_dim=512, m1=19, m2=10, s=4).to("cpu")
# input_data = (torch.randn(10, 1, 271, 281), torch.randint(0, 4, (10,)))
# model_summary = summary(meg_encoder, input_data=input_data)
# model_summary

In [None]:
class ThingsMEGDataset(torch.utils.data.Dataset):
    """Dataset of MEG trials with optional labels and image features.

    The tensors are read from ``<data_dir>/<split>_*.pt``. The ``train`` and
    ``val`` splits additionally load class labels and pre-computed image
    features. Every trial is baseline-corrected on load.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.
        data_dir: Directory containing the ``.pt`` files.
        baseline_frames: Number of leading samples along dim 2 whose mean is
            subtracted from each channel.
    """

    def __init__(self, split: str, data_dir: str = "data", baseline_frames: int = 50) -> None:
        super().__init__()

        assert split in ["train", "val", "test"], f"Invalid split: {split}"
        self.split = split
        self.num_classes = 1854
        self.baseline_frames = baseline_frames

        def _load(file_name):
            return torch.load(os.path.join(data_dir, file_name))

        self.X = _load(f"{split}_X.pt")
        self.subject_idxs = _load(f"{split}_subject_idxs.pt")

        # Only the labelled splits have targets and image features.
        if split in ("train", "val"):
            self.y = _load(f"{split}_y.pt")
            self.img_features = _load(f"{split}_image_encoded.pt")
            assert len(torch.unique(self.y)) == self.num_classes, "Number of classes do not match."

        self.apply_baseline_correction()

    def apply_baseline_correction(self):
        """Subtract each channel's mean over the first ``baseline_frames`` samples."""
        pre_stimulus = self.X[:, :, :self.baseline_frames]
        self.X = self.X - pre_stimulus.mean(dim=2, keepdim=True)

    def __len__(self) -> int:
        return len(self.X)

    def __getitem__(self, i):
        """Return (X, subject) without labels, else (X, img_features, y, subject)."""
        if not hasattr(self, "y"):
            return self.X[i], self.subject_idxs[i]
        return self.X[i], self.img_features[i], self.y[i], self.subject_idxs[i]

    @property
    def num_channels(self) -> int:
        return self.X.shape[1]

    @property
    def seq_len(self) -> int:
        return self.X.shape[2]

# Shared DataLoader settings; only the training loader shuffles.
data_dir = "data"
loader_args = dict(batch_size=128, num_workers=0)

train_set = ThingsMEGDataset("train", data_dir)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_args)

val_set = ThingsMEGDataset("val", data_dir)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, **loader_args)

test_set = ThingsMEGDataset("test", data_dir)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_args)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

max_val_acc = 0

# Top-10 multiclass accuracy over all training classes.
accuracy = Accuracy(task="multiclass", num_classes=train_set.num_classes, top_k=10)
accuracy = accuracy.to(device)

In [None]:
from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy_contrastive(logits, labels):
    """Return the fraction of rows whose arg-max column equals the label.

    Args:
        logits: 2-D score tensor, one row per sample.
        labels: 1-D tensor with the target column index of every row.

    Returns:
        float: Top-1 accuracy over the batch.
    """
    hits = (logits.argmax(dim=1) == labels).sum().item()
    return hits / labels.size(0)

def train_epoch(model, dataloader, optimizer):
    """Run one contrastive training epoch.

    Within a batch, the i-th MEG trial and the i-th image form the positive
    pair, so the targets are simply ``0..batch-1``. The loss is the mean of
    the cross-entropy over rows and over columns of the logit matrix.

    Args:
        model: Module called as ``model(X, image_features, subject_idxs)``.
        dataloader: Yields ``(X, image_features, y, subject_idxs)`` batches.
        optimizer: Optimizer updating ``model``.

    Returns:
        tuple: Mean loss and mean in-batch accuracy over all batches.
    """
    model.train()
    losses, accuracies = [], []

    for X, image_features, y, subject_idxs in tqdm(dataloader, desc="Train"):
        # Add the singleton dimension expected by the encoder.
        X = X.to(device).unsqueeze(1)
        image_features = image_features.to(device)
        subject_idxs = subject_idxs.to(device)

        logits = model(X, image_features, subject_idxs)
        targets = torch.arange(logits.shape[0]).to(device)

        row_loss = F.cross_entropy(logits, targets)
        col_loss = F.cross_entropy(logits.t(), targets)
        loss = (row_loss + col_loss) / 2

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accuracies.append(accuracy_contrastive(logits, targets))
        losses.append(loss.item())

    return np.mean(losses), np.mean(accuracies)

def validate_epoch(model, dataloader):
    """Evaluate the contrastive objective on a data loader without gradients.

    Args:
        model: Module called as ``model(X, image_features, subject_idxs)``.
        dataloader: Yields ``(X, image_features, y, subject_idxs)`` batches.

    Returns:
        tuple: Mean loss and mean in-batch accuracy over all batches.
    """
    model.eval()
    losses, accuracies = [], []

    with torch.no_grad():
        for X, image_features, y, subject_idxs in tqdm(dataloader, desc="Validation"):
            # Add the singleton dimension expected by the encoder.
            X = X.to(device).unsqueeze(1)
            image_features = image_features.to(device)
            subject_idxs = subject_idxs.to(device)

            logits = model(X, image_features, subject_idxs)
            targets = torch.arange(logits.shape[0]).to(device)

            row_loss = F.cross_entropy(logits, targets)
            col_loss = F.cross_entropy(logits.t(), targets)
            loss = (row_loss + col_loss) / 2

            accuracies.append(accuracy_contrastive(logits, targets))
            losses.append(loss.item())

    return np.mean(losses), np.mean(accuracies)

def run_training(model, train_loader, val_loader, optimizer, epochs, save_path):
    """Train for ``epochs`` epochs, checkpointing on every validation improvement.

    Args:
        model: Contrastive model to optimise.
        train_loader: Loader used by ``train_epoch``.
        val_loader: Loader used by ``validate_epoch``.
        optimizer: Optimizer updating ``model``.
        epochs: Number of epochs to run.
        save_path: Where the best ``state_dict`` is written.
    """
    best_acc = 0

    for epoch_idx in range(epochs):
        header = f"Epoch {epoch_idx+1}/{epochs}"
        print(header)

        train_loss, train_acc = train_epoch(model, train_loader, optimizer)
        val_loss, val_acc = validate_epoch(model, val_loader)

        print(f"{header} | train loss: {train_loss:.3f} | train acc: {train_acc:.3f} | val loss: {val_loss:.3f} | val acc: {val_acc:.3f}")

        # Only strictly better validation accuracy triggers a checkpoint.
        if val_acc <= best_acc:
            continue
        cprint("New best.", "cyan")
        best_acc = val_acc
        torch.save(model.state_dict(), save_path)

# Contrastive pre-training: build the MEG encoder and both projection heads,
# then train in two stages on the same model instance.
emb_dim = 512  # feature size fed to the image projection head
emb_meg = 768  # size of the shared embedding space
meg_encoder = MEGEncoder(out_dim=emb_meg).to(device)
proj_meg = ProjMeg(emb_dim=emb_meg, out_dim=emb_meg).to(device)
proj_img = ProjImg(emb_dim=emb_dim, out_dim=emb_meg).to(device)
model = ContrastiveModel(meg_encoder, proj_meg, proj_img).to(device)

# Stage 1: Adam, 20 epochs.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
run_training(model, train_loader, val_loader, optimizer, epochs=20, save_path="best_attn.pth")

# Stage 2: AdamW with a smaller learning rate, continuing from the weights
# left in `model` after the last stage-1 epoch (the stage-1 checkpoint is not
# reloaded here).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, betas=(0.5, 0.999))
run_training(model, train_loader, val_loader, optimizer, epochs=3, save_path="best_attn_adamw.pth")

## 識別器の学習


In [None]:

class Classifier(nn.Sequential):
    """Classification head on top of the projected MEG embedding.

    Layout: Linear -> residual(GELU, Linear, Dropout) -> LayerNorm.

    Args:
        emb_dim: Input feature size.
        out_dim: Output size (number of classes).
        drop: Dropout probability inside the residual branch.
    """
    def __init__(self, emb_dim=512, out_dim=512, drop=0.3):
        input_proj = nn.Linear(emb_dim, out_dim)
        residual_branch = nn.Sequential(
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
            nn.Dropout(drop),
        )
        output_norm = nn.LayerNorm(out_dim)
        super().__init__(input_proj, ResidualAdd(residual_branch), output_norm)
        
import torch
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def accuracy_contrastive(logits, labels):
    """Compute top-1 accuracy of ``logits`` against ``labels``.

    Args:
        logits: 2-D score tensor, one row per sample.
        labels: 1-D tensor with the target column index of every row.

    Returns:
        float: Number of correct arg-max predictions divided by the batch size.
    """
    predicted = logits.argmax(dim=1)
    num_correct = predicted.eq(labels).sum().item()
    return num_correct / labels.size(0)

def train_one_epoch(model, classifier, optimizer, data_loader, epoch, epochs, is_train=True):
    """Run one classification epoch, either training or evaluating.

    The MEG encoder and MEG projection of ``model`` produce L2-normalised
    features that ``classifier`` maps to class logits.

    Args:
        model: Object exposing ``meg_encoder`` and ``proj_meg`` modules.
        classifier: Head mapping features to class logits.
        optimizer: Optimizer stepped only when ``is_train`` is true.
        data_loader: Yields ``(X, image_features, y, subject_idxs)`` batches.
        epoch: Zero-based epoch index (progress-bar text only).
        epochs: Total number of epochs (progress-bar text only).
        is_train: Train when true; otherwise evaluate with gradients disabled.

    Returns:
        tuple: Loss and ``accuracy_multi`` value averaged over the batches.
    """
    # train(False) is equivalent to eval().
    model.train(is_train)
    classifier.train(is_train)
    desc = "Train" if is_train else "Validation"

    loss_sum, acc_sum = 0, 0
    progress = tqdm(data_loader, desc=f"{desc} Epoch {epoch+1}/{epochs}")

    with torch.set_grad_enabled(is_train):
        for X, _, y, subject_idxs in progress:
            # Add the singleton dimension expected by the encoder.
            X = X.to(device).unsqueeze(1)
            y = y.to(device)
            subject_idxs = subject_idxs.to(device)

            features = model.proj_meg(model.meg_encoder(X, subject_idxs))
            features = features / features.norm(dim=1, keepdim=True)

            logits = classifier(features)
            loss = F.cross_entropy(logits, y)

            if is_train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # `accuracy_multi` is the module-level metric object.
            batch_acc = accuracy_multi(logits, y)

            loss_sum += loss.item()
            acc_sum += batch_acc

    num_batches = len(data_loader)
    loss_sum /= num_batches
    acc_sum /= num_batches

    return loss_sum, acc_sum

def save_model_state(model, classifier, file_name):
    """Save the MEG encoder, MEG projection and classifier weights to one file.

    Args:
        model: Object exposing ``meg_encoder`` and ``proj_meg`` modules.
        classifier: Classification head.
        file_name: Destination passed to ``torch.save``.
    """
    checkpoint = {}
    checkpoint['meg_encoder'] = model.meg_encoder.state_dict()
    checkpoint['proj_meg'] = model.proj_meg.state_dict()
    checkpoint['classifier'] = classifier.state_dict()
    torch.save(checkpoint, file_name)

def load_model_state(model, classifier, file_name):
    """Restore weights written by ``save_model_state`` into existing modules.

    Args:
        model: Object exposing ``meg_encoder`` and ``proj_meg`` modules.
        classifier: Classification head.
        file_name: Checkpoint file holding the ``'meg_encoder'``,
            ``'proj_meg'`` and ``'classifier'`` state dicts.
    """
    # Deserialise onto the CPU so a checkpoint saved from a GPU also loads on
    # a CPU-only machine; load_state_dict copies the values into the modules'
    # existing parameters, wherever those live.
    checkpoint = torch.load(file_name, map_location="cpu")
    model.meg_encoder.load_state_dict(checkpoint['meg_encoder'])
    model.proj_meg.load_state_dict(checkpoint['proj_meg'])
    classifier.load_state_dict(checkpoint['classifier'])


# Classifier fine-tuning: rebuild the contrastive model, load pre-trained
# weights, attach a classification head and train everything in two stages.
emb_dim = 512
emb_meg = 768

meg_encoder_pred = MEGEncoder(out_dim=emb_meg).to(device)
proj_meg_pred = ProjMeg(emb_dim=emb_meg, out_dim=emb_meg).to(device)
proj_img_pred = ProjImg(emb_dim=512, out_dim=emb_meg).to(device)
contrastive_model_pred = ContrastiveModel(meg_encoder_pred, proj_meg_pred, proj_img_pred).to(device)
# NOTE(review): the pre-training cell writes 'best_attn_adamw.pth'; this file
# name carries an extra suffix - presumably a manually renamed copy, confirm.
contrastive_model_pred.load_state_dict(torch.load('best_attn_adamw_07180006.pth'))

classifier = Classifier(emb_dim=emb_meg, out_dim=test_set.num_classes).to(device)
# The encoder, the MEG projection and the classifier are all optimised; the
# image projection is not part of the parameter list.
optimizer = torch.optim.Adam(list(meg_encoder_pred.parameters()) + list(proj_meg_pred.parameters()) + list(classifier.parameters()), lr=1e-3)

# NOTE(review): `max_val_acc` is not read by the loops below, which track
# `best_val_acc` instead.
max_val_acc = 0
# Top-10 accuracy metric; train_one_epoch reads it as a module-level name.
accuracy_multi = Accuracy(task="multiclass", num_classes=train_set.num_classes, top_k=10).to(device)

# Stage 1: Adam.
epochs = 15
best_val_acc = 0
for epoch in range(epochs):
    train_loss, train_acc = train_one_epoch(contrastive_model_pred, classifier, optimizer, train_loader, epoch, epochs, is_train=True)
    val_loss, val_acc = train_one_epoch(contrastive_model_pred, classifier, optimizer, val_loader, epoch, epochs, is_train=False)

    print(f"Epoch {epoch+1}/{epochs} | train loss: {train_loss:.3f} | train acc: {train_acc:.3f} | val loss: {val_loss:.3f} | val acc: {val_acc:.3f}")

    # Checkpoint only on a strict improvement of the validation accuracy.
    if val_acc > best_val_acc:
        print("New best model saved.")
        best_val_acc = val_acc
        save_model_state(contrastive_model_pred, classifier, 'best_ft_model_attn.pth')

# NOTE(review): stage 1 saves 'best_ft_model_attn.pth'; the file loaded here
# has a different name - presumably a renamed copy of that checkpoint, confirm.
load_model_state(contrastive_model_pred, classifier, 'best_ft_model_attn_07180020.pth')

# Stage 2: AdamW with a smaller learning rate, starting from the loaded weights.
optimizer = torch.optim.AdamW(list(meg_encoder_pred.parameters()) + list(proj_meg_pred.parameters()) + list(classifier.parameters()), lr=2e-4, betas=(0.5, 0.999))

epochs = 15
best_val_acc = 0
for epoch in range(epochs):
    train_loss, train_acc = train_one_epoch(contrastive_model_pred, classifier, optimizer, train_loader, epoch, epochs, is_train=True)
    val_loss, val_acc = train_one_epoch(contrastive_model_pred, classifier, optimizer, val_loader, epoch, epochs, is_train=False)

    print(f"Epoch {epoch+1}/{epochs} | train loss: {train_loss:.3f} | train acc: {train_acc:.3f} | val loss: {val_loss:.3f} | val acc: {val_acc:.3f}")

    if val_acc > best_val_acc:
        print("New best model saved.")
        best_val_acc = val_acc
        save_model_state(contrastive_model_pred, classifier, 'best_ft_model_attn_adamw.pth')


## submissionの生成


In [None]:

# Build the inference stack (encoder -> MEG projection -> classifier) and
# write the test-set logits as the submission file.
emb_dim = 512
emb_meg = 768
meg_encoder_pred = MEGEncoder(out_dim=emb_meg).to(device)
proj_meg_pred = ProjMeg(emb_dim=emb_meg, out_dim=emb_meg).to(device)
classifier = Classifier(emb_dim=emb_meg, out_dim=test_set.num_classes).to(device)

# NOTE(review): the fine-tuning cell saves 'best_ft_model_attn_adamw.pth';
# confirm that 'best_ft_model_emb_adamw.pth' is the intended checkpoint.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
checkpoint = torch.load('best_ft_model_emb_adamw.pth', map_location=device)

meg_encoder_pred.load_state_dict(checkpoint['meg_encoder'])
proj_meg_pred.load_state_dict(checkpoint['proj_meg'])
classifier.load_state_dict(checkpoint['classifier'])

preds = []
meg_encoder_pred.eval()
proj_meg_pred.eval()
classifier.eval()

# Inference only: disabling autograd avoids building a graph for every batch.
with torch.no_grad():
    for X, subject_idxs in tqdm(test_loader, desc="Validation"):
        X = X.to(device)
        subject_idxs = subject_idxs.to(device)
        # Add the singleton dimension expected by the encoder.
        X = X.unsqueeze(1)

        features = meg_encoder_pred(X, subject_idxs)
        features = proj_meg_pred(features)
        features = features / features.norm(dim=1, keepdim=True)

        logits = classifier(features)

        preds.append(logits.cpu())

preds = torch.cat(preds, dim=0).numpy()
# np.save appends ".npy", so the file written is data/submission.npy.
np.save(os.path.join("data", "submission"), preds)
cprint(f"Submission {preds.shape} saved at data", "cyan")

