# U-Netのセグメンテーションモデル

In [None]:
pip install segmentation_models_pytorch

In [None]:
pip install pydicom

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import pydicom
import numpy as np
import os
import glob
from tqdm import tqdm
import gc

import torchvision
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from fastai.vision.all import *
import segmentation_models_pytorch as smp

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Cross-validation and reproducibility settings.
CV = 5
SEED = 777
fold = 1

# PATCH_SIZE is the square side every slice is resized to; patch_size is the
# default crop side of ViT_T1_Dataset (reassigned in a later cell).
PATCH_SIZE = 512
patch_size = 64

# Threshold applied to the U-Net heat maps when locating the disc levels.
TH = .5

# Gate for the segmentation training cell (flipped to True further down).
SEG_TRAIN = False

# Hyper-parameters of the segmentation (SEG) and classification (INF) runs.
SEG = dict(BS=16, LR=5e-4, EPOCHS=10)
INF = dict(BS=64, LR=1e-5, EPOCHS=10)

In [None]:
!unzip /content/drive/MyDrive/rsna-2024-lumbar-spine-degenerative-classification.zip -d /content #RSNA2024を今のディレクトリに展開する

In [None]:
# Load the training label table from the extracted archive and preview its last rows.
train = pd.read_csv('/content/train.csv')
train.tail()

In [None]:
# Columns whose name mentions 'foraminal'; studies missing any of them are dropped.
diagnosis = [column for column in train.columns if 'foraminal' in column]
train = train[train[diagnosis].notna().all(axis=1)].reset_index(drop=True)
train.tail()

In [None]:
# Study identifier plus the diagnosis columns selected above.
# NOTE(review): train2 is not referenced again in the visible part of the notebook.
train2=train[['study_id']+diagnosis]

In [None]:
# Severity string -> integer class index.
labels = {
    'Normal/Mild':0,
    'Moderate':1,
    'Severe':2
}

In [None]:
# Series description table (joined on study_id / series_id further down).
df_meta_f = pd.read_csv('/content/train_series_descriptions.csv')
df_meta_f.tail()

In [None]:
# Annotated label coordinates (condition, level, x, y per study / series / instance).
df_coor = pd.read_csv('/content/train_label_coordinates.csv')
df_coor.tail()

In [None]:
# Spinal-canal-stenosis annotations only, ordered by study / series / level,
# with exact duplicate rows removed.
_rf_columns = ['study_id', 'series_id', 'instance_number', 'level', 'x', 'y']
RF = (
    df_coor.loc[df_coor['condition'] == 'Spinal Canal Stenosis', _rf_columns]
    .sort_values(['study_id', 'series_id', 'level'])
    .drop_duplicates()
)
RF.tail()

In [None]:
# Nested lookup: centers[study_id][series_id][level] -> list of [x, y] annotations.
centers = {}
for i in range(len(RF)):
    row = RF.iloc[i]
    per_series = centers.setdefault(row['study_id'], {})
    per_level = per_series.setdefault(
        row['series_id'],
        {'L1/L2':[],'L2/L3':[],'L3/L4':[],'L4/L5':[],'L5/S1':[]},
    )
    per_level[row['level']].append([row['x'],row['y']])

In [None]:
# One row per RF row: the mean (x, y) of every annotated level of that row's
# series, laid out as [x_L1L2, y_L1L2, ..., x_L5S1, y_L5S1]; NaN where a level
# has no annotation.
_level_offset = {'L1/L2':0, 'L2/L3':2, 'L3/L4':4, 'L4/L5':6, 'L5/S1':8}
coordinates = np.full((len(RF),10), np.nan)
for i in range(len(RF)):
    row = RF.iloc[i]
    series_centers = centers[row['study_id']][row['series_id']]
    for level, points in series_centers.items():
        if len(points) == 0:
            continue
        start = _level_offset[level]
        coordinates[i, start:start+2] = np.array(points).mean(0)

In [None]:
# Replace the per-annotation x / y with the ten per-level mean coordinates and
# collapse rows that became identical.
RF = RF[['study_id', 'series_id', 'instance_number', 'x', 'y']]
_level_xy_columns = [
    f'{axis}_{level}'
    for level in ('L1L2', 'L2L3', 'L3L4', 'L4L5', 'L5S1')
    for axis in ('x', 'y')
]
RF[_level_xy_columns] = coordinates
RF = RF.drop(columns=['x','y']).drop_duplicates().reset_index(drop=True)
RF.tail()

In [None]:
# Keep only rows in which all ten per-level coordinates are present.
_has_all_levels = RF[[
    f'{axis}_{level}'
    for level in ('L1L2', 'L2L3', 'L3L4', 'L4L5', 'L5S1')
    for axis in ('x', 'y')
]].notna().all(axis=1)
RF = RF[_has_all_levels].reset_index(drop=True)
RF.tail()

In [None]:
# Attach the columns whose name mentions 'spinal' to each RF row via study_id.
diagnosis = [column for column in train.columns if 'spinal' in column]
RF = RF.merge(train[['study_id', *diagnosis]], on='study_id')
RF.tail()

In [None]:
# Severity string -> integer class index (same mapping as defined earlier).
labels = {
    'Normal/Mild':0,
    'Moderate':1,
    'Severe':2
}

# Names of the ten per-level coordinate columns, in (x, y) pairs from L1/L2 to
# L5/S1; T1Dataset reads them in this order and reshapes them to (5, 2).
coor = [
    'x_L1L2',
    'y_L1L2',
    'x_L2L3',
    'y_L2L3',
    'x_L3L4',
    'y_L3L4',
    'x_L4L5',
    'y_L4L5',
    'x_L5S1',
    'y_L5S1',
]

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Args:
        seed: Integer seed applied to every generator; also exported as the
            ``PYTHONHASHSEED`` environment variable.
    """
    # Imported here so the function does not rely on `random` leaking into the
    # namespace through `from fastai.vision.all import *`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def augment_image_and_centers(image,centers,alpha):
    """Rotate an image by a random angle and move its landmarks with it.

    Args:
        image: Image tensor accepted by
            ``torchvision.transforms.functional.rotate``; T1Dataset passes a
            (1, PATCH_SIZE, PATCH_SIZE) tensor.
        centers: Tensor of (x, y) rows in pixel coordinates of the
            PATCH_SIZE x PATCH_SIZE image.
        alpha: Augmentation strength. The angle is drawn uniformly from
            [-180 * alpha, 180 * alpha] degrees, so 0 disables the rotation.

    Returns:
        Tuple ``(image, centers)``: the rotated image and the landmarks
        rotated about (PATCH_SIZE // 2, PATCH_SIZE // 2), on the CPU.
    """
    # Local imports: `math` and `random` were previously only available through
    # the fastai star import.
    import math
    import random

    # Randomly rotate the image.
    angle = torch.as_tensor(random.uniform(-180, 180)*alpha)
    image = torchvision.transforms.functional.rotate(image,angle.item())

    # rotate() turns the picture counter-clockwise for a positive angle. In
    # pixel coordinates (x to the right, y downwards) a point therefore moves
    # to (x*cos + y*sin, -x*sin + y*cos); written for row vectors this is the
    # matrix below. The previous matrix was its transpose, which moved the
    # landmarks in the opposite direction to the image whenever angle != 0.
    angle = angle*math.pi/180
    s = torch.sin(angle)
    c = torch.cos(angle)
    rot = torch.stack([
        torch.stack([c, -s]),
        torch.stack([s, c])
    ])
    centers = ((centers.cpu() - PATCH_SIZE//2) @ rot) + PATCH_SIZE//2
    return image,centers

In [None]:
class T1Dataset(Dataset):
    """Slices with their five disc-level centres, for heat-map U-Net training.

    Each item reads one DICOM slice, crops it to a square along the longer
    axis, resizes it to PATCH_SIZE x PATCH_SIZE, scales it by its maximum and
    maps the ten `coor` columns of the row into the resized pixel frame.

    Args:
        df: DataFrame with `study_id`, `series_id`, `instance_number` and the
            ten columns listed in `coor`.
        VALID: When True the crop is centred and no augmentation is applied.
        alpha: Augmentation strength used for the random crop offset and passed
            to `augment_image_and_centers`.
    """

    def __init__(self, df, VALID=False, alpha=0):
        self.data = df
        self.VALID = VALID
        self.alpha = alpha

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(image, [label, centers])`` with every tensor on `device`.

        `image` is (1, PATCH_SIZE, PATCH_SIZE) float, `centers` is (5, 2) float
        in (x, y) order and `label` is the constant tensor 1 (a placeholder the
        heat-map loss ignores).
        """
        row = self.data.iloc[index]

        centers = torch.as_tensor([x for x in row[coor]]).view(5,2).float()

        sample = f"/content/train_images/{row['study_id']}/{row['series_id']}/{row['instance_number']}.dcm"
        image = pydicom.dcmread(sample).pixel_array
        H,W = image.shape
        # Crop to a square first: resizing the raw rectangle would distort the
        # proportions. Training jitters the crop offset, validation centres it.
        if H > W:
            d = W
            if not self.VALID:
                h = int((H - d)*(.5 + self.alpha*(.5 - np.random.rand())))
            else:
                h = (H - d)//2
            image = image[h:h+d]
            centers[:,1] -= h
            H = W
        elif H < W:
            d = H
            if not self.VALID:
                w = int((W - d)*(.5 + self.alpha*(.5 - np.random.rand())))
            else:
                w = (W - d)//2
            image = image[:,w:w+d]
            centers[:,0] -= w
            W = H
        image = cv2.resize(image,(PATCH_SIZE,PATCH_SIZE))
        # Scale by the maximum; a slice whose maximum is 0 stays all-zero
        # instead of turning into NaNs through 0/0.
        peak = np.max(image)
        image = image/peak if peak != 0 else np.zeros_like(image, dtype=np.float64)
        image = torch.as_tensor(image).unsqueeze(0).float()

        label = torch.as_tensor(1)

        # Bring the centres into the resized frame (H == W after the crop).
        centers[:,0] = centers[:,0]*PATCH_SIZE/W
        centers[:,1] = centers[:,1]*PATCH_SIZE/H

        if not self.VALID: image,centers = augment_image_and_centers(image,centers,self.alpha)
        return image.to(device),[label.to(device),centers.to(device)]

In [None]:
class myUNet3(nn.Module):
    """ResNet-18 U-Net whose five output channels are min-max scaled heat maps."""

    def __init__(self):
        super(myUNet3, self).__init__()

        self.UNet = smp.Unet(
            encoder_name="resnet18",
            classes=5,
            in_channels=1
        ).to(device)

    def forward(self,X):
        """Return the U-Net output rescaled to [0, 1] per sample and channel.

        Args:
            X: Input batch for the wrapped U-Net (built with ``in_channels=1``).

        Returns:
            Tensor shaped like the raw U-Net output, each channel map shifted
            and scaled by its own minimum and maximum.
        """
        x = self.UNet(X)
        # Min-max scale each channel plane. The extrema are taken over the
        # flattened spatial axes, so any output size and channel count works
        # (the previous view() hard-coded 5 channels and PATCH_SIZE).
        flat = x.flatten(2)
        min_values = flat.min(-1)[0][..., None, None]
        max_values = flat.max(-1)[0][..., None, None]
        # A constant plane has max == min; clamp the span so it maps to zeros
        # instead of NaN.
        span = (max_values - min_values).clamp_min(torch.finfo(x.dtype).eps)
        x = (x - min_values)/span

        return x

In [None]:
# Pixel-coordinate grid of shape (1, 1, 2, PATCH_SIZE, PATCH_SIZE): channel 0
# holds the column (x) index of every pixel, channel 1 the row (y) index.
_column_index = torch.arange(PATCH_SIZE).repeat(PATCH_SIZE, 1).to(device)
idx_map = torch.stack([_column_index, _column_index.T]).view(1,1,2,PATCH_SIZE,PATCH_SIZE)
class myLoss3(nn.Module):
    """Heat-map loss: one minus the squared cosine similarity to ideal Gaussians.

    The ideal heat map of each of the five levels is a normal density centred
    on the annotated point with variance PATCH_SIZE / 8.
    """

    def __init__(self, alpha=.5):
        super().__init__()
        # Stored and forwarded by clone(); forward() does not read it.
        self.alpha = alpha

    def clone(self):
        """Return a new loss instance carrying the same alpha."""
        return myLoss3(self.alpha)

    def forward(self, y, t):
        """Compute the distance between predicted and ideal heat maps.

        Args:
            y: Predicted heat maps, broadcast against (batch, 5, PATCH_SIZE,
                PATCH_SIZE).
            t: Pair ``(label, centers)``; only `centers` is used and it is
                viewed as (batch, 5, 2) in (x, y) order.

        Returns:
            Scalar tensor ``1 - <ideal, pred>^2 / (|ideal|^2 * |pred|^2)``,
            summed over the whole batch.
        """
        heat_pred = y
        _unused_label, centers_true = t
        # Per-level variance of the ideal Gaussians.
        variance = torch.as_tensor([PATCH_SIZE/8]*5)
        # Exponent factor and normalisation constant of those Gaussians.
        exponent_scale = -1/(2*variance).to(device)
        norm_const = 1/torch.sqrt(2*math.pi*variance).to(device)
        # Rescale the predictions to the same amplitude as the ideal maps.
        heat_pred = heat_pred*norm_const.view(1,5,1,1)
        # Ideal maps: squared offset of every pixel from its level centre.
        offsets = idx_map - centers_true.view(-1,5,2,1,1)
        squared = exponent_scale.view(-1,5,1,1,1)*offsets*offsets
        heat_true = torch.exp(squared.sum(2))*norm_const.view(-1,5,1,1)
        # Distance.
        overlap = (heat_true*heat_pred).sum()
        energy = (heat_true*heat_true).sum()*(heat_pred*heat_pred).sum()
        return 1 - overlap**2/energy


In [None]:
import numpy as np

# Give every RF row a fold id by cycling 1, 2, 3, 4, 5 down the rows.
# NOTE(review): folds are assigned per row, not per study_id, so slices of one
# study can land in different folds - confirm this is intended.
n = len(RF)
values = np.arange(n) % 5 + 1
RF['series_description2'] = values

# Preview the assignment.
RF.head()

In [None]:
# Enable the segmentation training cell below (overrides the False set in the config cell).
SEG_TRAIN = True

In [None]:
import pickle

# Hold out the rows whose fold id equals `fold` for validation.
tdf = RF[RF['series_description2'] != fold]
vdf = RF[RF['series_description2'] == fold]

# tdf2=df_melted[df_melted['series_description'] != fold]
# vdf2=df_melted[df_melted['series_description'] == fold]

# Training set uses the default alpha=0; validation crops are centred.
tds = T1Dataset(tdf)
vds = T1Dataset(vdf,VALID=True)
tdl = torch.utils.data.DataLoader(tds, batch_size=SEG['BS'], shuffle=True, drop_last=True)
vdl = torch.utils.data.DataLoader(vds, batch_size=SEG['BS'], shuffle=False)

if SEG_TRAIN:
    seed_everything(SEED)

    dls = DataLoaders(tdl,vdl)

    # NOTE(review): n_iter is computed but not used in this cell.
    n_iter = len(tds)//SEG['BS']

    model = myUNet3()
    learn = Learner(
        dls,
        model,
        lr=SEG['LR'],
        loss_func=myLoss3(alpha=0.5),
        # cbs=[
        #     ShowGraphCallback(),
        #     alpha_cb
        # ]
    )
    learn.fit_one_cycle(SEG['EPOCHS'])
    # The whole model object is pickled.
    # NOTE(review): the file is written as "SEG_SCS_<fold>_30.pkl", while the
    # later cells load "SEG_<fold>.pkl" and "SEG_SCS_1.pkl" - confirm which
    # file name is intended.
    with open('/content/drive/MyDrive/RSNA_csv/'+"SEG_"+"SCS_"+str(fold)+"_30"+".pkl", 'wb') as f:
      pickle.dump(model, f)
    # Drop the references and collect garbage before the next stage.
    del tdl,vdl,dls,model,learn
    gc.collect()

# sagittal T2画像の前処理

In [None]:
base_dir = '/content/train_images/'

# Collect one [study_id, series_id, instance_number] record for every file
# laid out as <base_dir>/<study_id>/<series_id>/<instance_number>.dcm.
data = []
for study_id in os.listdir(base_dir):
    study_path = os.path.join(base_dir, study_id)
    if not os.path.isdir(study_path):
        continue
    for series_id in os.listdir(study_path):
        series_path = os.path.join(study_path, series_id)
        if not os.path.isdir(series_path):
            continue
        # The instance number is the file name up to its first dot.
        data.extend(
            [study_id, series_id, filename.split('.')[0]]
            for filename in os.listdir(series_path)
            if filename.endswith('.dcm')
        )

# Build the pandas DataFrame.
df = pd.DataFrame(data, columns=['study_id', 'series_id', 'instance_number'])

In [None]:
# Reload the series description table.
df_meta_f = pd.read_csv('/content/train_series_descriptions.csv')
df_meta_f.tail()

# The ids were collected from folder and file names as strings; cast them to
# int64 before merging (presumably to match the dtypes of the CSV keys).
df = df.astype('int64')

# Attach the series description to every file and keep the Sagittal T2/STIR ones.
all_df = pd.merge(df, df_meta_f, on=['study_id', 'series_id'], how='inner')
s2_all_df=all_df[all_df['series_description']=='Sagittal T2/STIR']

In [None]:
import os
import numpy as np

# Work on a copy so the filtered frame `s2_all_df` is left untouched.
s2_all_df_copy = s2_all_df.copy()

# Header values of every slice, collected in row order and attached as float
# columns afterwards (one assignment per column instead of five `.loc` writes
# per row into object-typed columns).
_header_values = {
    'x_pos': [],
    'y_pos': [],
    'z_pos': [],
    'pixel_sp_z': [],
    'pixel_sp_y': [],
}
for study_id, series_id, instance_number in zip(
    s2_all_df_copy['study_id'],
    s2_all_df_copy['series_id'],
    s2_all_df_copy['instance_number'],
):
    # Build the path of the DICOM file.
    dicom_file_path = f"/content/train_images/{study_id}/{series_id}/{instance_number}.dcm"

    # Only header tags are read below, so skip loading the pixel data.
    dicom_data = pydicom.dcmread(dicom_file_path, stop_before_pixels=True)

    # 'Image Position (Patient)' components and the two 'Pixel Spacing' values.
    position = dicom_data.ImagePositionPatient
    spacing = dicom_data.PixelSpacing
    _header_values['x_pos'].append(float(position[0]))
    _header_values['y_pos'].append(float(position[1]))
    _header_values['z_pos'].append(float(position[2]))
    _header_values['pixel_sp_z'].append(float(spacing[0]))
    _header_values['pixel_sp_y'].append(float(spacing[1]))

for _column, _values in _header_values.items():
    s2_all_df_copy[_column] = _values


print(s2_all_df_copy)

In [None]:
import pandas as pd
# Convert 'x_pos' to numeric, setting errors='coerce' to handle any non-numeric values
s2_all_df_copy['x_pos'] = pd.to_numeric(s2_all_df_copy['x_pos'], errors='coerce')

# Group by 'study_id' and sort within each group by 'x_pos'
# NOTE(review): the grouping key is study_id only, so a study with more than
# one Sagittal T2/STIR series has its series interleaved here - confirm that
# downstream code (which takes series_id from the first row of a group) copes.
s2_all_df_sorted = s2_all_df_copy.groupby('study_id').apply(lambda x: x.sort_values('x_pos')).reset_index(drop=True)


In [None]:
import pandas as pd

# Pick `num` evenly spaced rows from a group.
def select_evenly_spaced(group, num=10):
    """Return `num` rows spread evenly over `group`.

    Args:
        group: DataFrame (one groupby group).
        num: Number of rows to keep.

    Returns:
        `group` itself when it has `num` rows or fewer; otherwise the rows at
        positions ``int(i * len(group) / num)`` for ``i`` in ``range(num)``.
    """
    size = len(group)
    if size > num:
        stride = size / num
        return group.iloc[[int(position * stride) for position in range(num)]]
    return group

# Group by study_id only and keep at most 10 evenly spaced rows per study.
new_df = s2_all_df_sorted.groupby(['study_id'], group_keys=False).apply(select_evenly_spaced)

# Inspect the result.
new_df.head(20)

In [None]:
import pandas as pd

# Make five copies of new_df, one per level index 0..4 (stored in a new
# `level` column), and stack them vertically.
dfs = []
for level_index in range(5):
    dfs.append(new_df.assign(level=level_index))
combined_df = pd.concat(dfs, ignore_index=True)

# Inspect the result.
print(combined_df.head())
print(combined_df['level'].value_counts())

In [None]:
# Attach the study-level label columns of `train` to every row (left join on study_id).
combined_df = pd.merge(combined_df,train, on='study_id', how='left')

In [None]:
# Display the L1/L2 spinal canal stenosis label column after the merge.
combined_df['spinal_canal_stenosis_l1_l2']

In [None]:
import pandas as pd
import numpy as np

# サンプルデータフレームの作成（ここでは仮のカラムを使用しています）
# 実際には combined_df には適切なカラムが含まれている必要があります
# combined_df = pd.DataFrame({
#     'description': [...],
#     'level': [...],
#     'left_neural_foraminal_narrowing_l1_l2': [...],
#     'left_neural_foraminal_narrowing_l2_l3': [...],
#     'left_neural_foraminal_narrowing_l3_l4': [...],
#     'left_neural_foraminal_narrowing_l4_l5': [...],
#     'left_neural_foraminal_narrowing_l5_s1': [...],
#     'right_neural_foraminal_narrowing_l1_l2': [...],
#     'right_neural_foraminal_narrowing_l2_l3': [...],
#     'right_neural_foraminal_narrowing_l3_l4': [...],
#     'right_neural_foraminal_narrowing_l4_l5': [...],
#     'right_neural_foraminal_narrowing_l5_s1': [...],
# })

# Build the `label` column.
def get_label(row):
    """Return the spinal canal stenosis label matching the row's level index.

    Args:
        row: Mapping-like row with a `level` entry (0 = L1/L2 ... 4 = L5/S1)
            and the `spinal_canal_stenosis_*` entries.

    Returns:
        The value stored under the column of that level, or NaN when `level`
        is not one of 0..4.
    """
    column_by_level = {
        0: 'spinal_canal_stenosis_l1_l2',
        1: 'spinal_canal_stenosis_l2_l3',
        2: 'spinal_canal_stenosis_l3_l4',
        3: 'spinal_canal_stenosis_l4_l5',
        4: 'spinal_canal_stenosis_l5_s1',
    }
    column = column_by_level.get(row['level'])
    if column is None:
        return np.nan  # unexpected level value
    return row[column]

# Create the new `label` column row by row.
combined_df['label'] = combined_df.apply(get_label, axis=1)

# Inspect the result.
print(combined_df.head())


In [None]:
def pad_group(df, target_size=10):
    """Pad a group to `target_size` rows by repeating its first row.

    Args:
        df: DataFrame holding one group.
        target_size: Minimum number of rows the result should have.

    Returns:
        `df` itself when it already has at least `target_size` rows or has no
        rows at all; otherwise a new frame with a fresh 0..n-1 index made of
        the original rows followed by copies of the first row.
    """
    missing = target_size - len(df)
    # Nothing to add, or nothing to add *with*: an empty group has no first
    # row to repeat (the previous while-loop never terminated on it).
    if missing <= 0 or len(df) == 0:
        return df
    first_row = df.iloc[0:1]  # first row, kept as a DataFrame
    # One concat instead of growing the frame a row at a time.
    return pd.concat([df] + [first_row] * missing, ignore_index=True)

# Pad every (study_id, level) group of combined_df to 10 rows with pad_group
# and stack the padded groups back into a single DataFrame.
grouped = combined_df.groupby(['study_id', 'level'])
padded_groups = [pad_group(group) for _, group in grouped]
padded_df = pd.concat(padded_groups, ignore_index=True)

In [None]:
# Spread the distinct instance numbers of a group over `instance_number_{i}` columns.
def assign_unique_instance_numbers(group, max_instances=10):
    """Collapse a group into one row carrying its distinct instance numbers.

    Args:
        group: DataFrame with an `instance_number` column (one groupby group).
        max_instances: Number of `instance_number_{i}` entries to produce.

    Returns:
        Series with every value of the group's first row followed by
        `instance_number_0` .. `instance_number_{max_instances - 1}`. Slots
        beyond the number of distinct instance numbers repeat the first one.
    """
    # Distinct instance numbers in order of first appearance.
    distinct = group['instance_number'].drop_duplicates().reset_index(drop=True)
    n_distinct = len(distinct)

    # Start from the first row so the other columns are preserved.
    record = group.iloc[0].to_dict()
    record.update({
        f'instance_number_{slot}': distinct[slot if slot < n_distinct else 0]
        for slot in range(max_instances)
    })
    return pd.Series(record)

# Group by study_id and level and collapse each group into a single row.
grouped_df_s2 = padded_df.groupby(['study_id', 'level']).apply(assign_unique_instance_numbers).reset_index(drop=True)

# 結果の確認

In [None]:
# Drop the rows that have no label.
grouped_df_s2=grouped_df_s2[grouped_df_s2['label'].notna()]

In [None]:
# Renumber the rows 0..n-1 in place after the filtering above.
grouped_df_s2.reset_index(drop=True,inplace=True)

In [None]:
# Persist the prepared table to Drive so the training section can start from it.
grouped_df_s2.to_pickle('/content/drive/MyDrive/RSNA_csv/grouped_df_s2_rev.pkl')

# 画像認識モデル学習部分

In [None]:
# Reload the table written at the end of the preprocessing section.
grouped_df_s2 = pd.read_pickle('/content/drive/MyDrive/RSNA_csv/grouped_df_s2_rev.pkl')

In [None]:
import albumentations as A

# Probability with which each augmentation stage below is applied.
AUG_PROB = 0.75
# Training augmentation pipeline, applied in the listed order: brightness /
# contrast jitter, one of three blur / noise transforms, one of three geometric
# distortions, then a shift / scale / rotate.
transforms_train = A.Compose([
    A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=AUG_PROB),
    A.OneOf([
        A.MotionBlur(blur_limit=5),
        # A.MedianBlur(blur_limit=5),
        A.GaussianBlur(blur_limit=5),
        A.GaussNoise(var_limit=(5.0, 30.0)),
    ], p=AUG_PROB),

    A.OneOf([
        A.OpticalDistortion(distort_limit=1.0),
        A.GridDistortion(num_steps=5, distort_limit=1.),
        A.ElasticTransform(alpha=3),
    ], p=AUG_PROB),

    A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=15, border_mode=0, p=AUG_PROB),
  #  A.CoarseDropout(max_holes=15, max_height=8, max_width=8, min_holes=1, min_height=8, min_width=8, p=AUG_PROB),
])

In [None]:
# Crop side used as the default `P` of ViT_T1_Dataset (replaces the 64 set in the config cell).
patch_size=90

In [None]:
import torch.nn.functional as F
class ViT_T1_Dataset(Dataset):
    """Ten-slice crop stack around one disc level, located with the heat-map U-Net.

    For each of the ten `instance_number_{i}` slices of a row, the slice is
    cropped to a square, resized to PATCH_SIZE, passed through `UNet`, and a
    P x P crop is taken around the centroid of the thresholded heat map of
    the row's `level`.

    Args:
        df: DataFrame with `study_id`, `series_id`, `level`, `label` and the
            `instance_number_0` .. `instance_number_9` columns.
        UNet: Heat-map model returning 5 channels of PATCH_SIZE x PATCH_SIZE.
        VALID: Stored on the instance; not read by `__getitem__`.
        P: Side of the square crop.
        alpha: Stored on the instance; not read by `__getitem__`.
        transform: Optional albumentations transform applied to the crop stack.
    """

    def __init__(self, df, UNet, VALID=False, P=patch_size, alpha=0,transform=None):
        self.data = df
        # The U-Net is only used for inference here, so switch it to eval mode.
        if isinstance(UNet, nn.Module):
            UNet.eval()
        self.UNet = UNet
        self.VALID = VALID
        self.P = P
        self.alpha = alpha
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def _load_slice(self, path):
        """Read a DICOM slice as a (1, 1, PATCH_SIZE, PATCH_SIZE) tensor on `device`."""
        image = pydicom.dcmread(path).pixel_array
        H,W = image.shape
        # Centre-crop to a square first so the resize keeps the proportions.
        if H > W:
            h = (H - W)//2
            image = image[h:h+W]
        elif H < W:
            w = (W - H)//2
            image = image[:,w:w+H]
        image = cv2.resize(image,(PATCH_SIZE,PATCH_SIZE))
        # Scale by the maximum; a blank slice stays all-zero instead of NaN.
        peak = np.max(image)
        image = image/peak if peak != 0 else np.zeros_like(image, dtype=np.float64)
        return torch.as_tensor(image).unsqueeze(0).unsqueeze(0).float().to(device)

    def _locate_levels(self, image):
        """Return ``(c, m)``: (5, 2) integer (x, y) centroids and a (5,) found-mask."""
        # Inference only: without no_grad every call would build an autograd graph.
        with torch.no_grad():
            OUT = self.UNet(image)
        OUT = (OUT > TH)[0]
        # Centroid of the thresholded pixels of every level.
        c = (OUT.unsqueeze(1)*idx_map[0]).view(5,2,PATCH_SIZE*PATCH_SIZE).sum(-1)
        d = OUT.view(5,PATCH_SIZE*PATCH_SIZE).sum(-1)
        m = d > 0
        c[m] = (c[m]/(d[m]).unsqueeze(-1)).long()
        c[~m] = self.P # I have to find a better solution
        return c, m

    def _crop(self, image, cx, cy):
        """Return the P x P crop around (cx, cy), or None when it leaves the image."""
        half = self.P // 2
        y_start = max(0, cy - half)
        y_end = min(PATCH_SIZE, cy + self.P - half)
        x_start = max(0, cx - half)
        x_end = min(PATCH_SIZE, cx + self.P - half)
        # Only full-size crops are valid.
        if (y_end - y_start == self.P) and (x_end - x_start == self.P):
            return image[0, 0, y_start:y_end, x_start:x_end]
        return None

    def __getitem__(self, index):
        """Return ``[x, m], [label, m[level]]`` with every tensor on `device`.

        `x` is the (10, 224, 224) crop stack, `label` the class index of the
        row and `m` the found-mask of the last of the ten slices.
        """
        row = self.data.iloc[index]
        level = int(row['level'])
        x = np.zeros((10, self.P, self.P), dtype=np.float32)
        non_zero_slice=[]

        for i in range(10):
            sample = f"/content/train_images/{row['study_id']}/{row['series_id']}/{row[f'instance_number_{i}']}.dcm"
            image = self._load_slice(sample)
            c, m = self._locate_levels(image)

            # Every valid crop (of any level) is remembered as a fallback.
            crops = []
            for xy in c:
                crop = self._crop(image, int(xy[0]), int(xy[1]))
                if crop is not None:
                    non_zero_slice.append(crop)
                crops.append(crop)

            # An invalid crop leaves this channel at zero.
            if crops[level] is not None:
                x[i,...] = crops[level].cpu().numpy()

        # Replace channels that stayed empty with the first valid crop seen.
        if len(non_zero_slice) > 0:
            fallback = non_zero_slice[0].cpu().numpy()
            for i in range(10):
                if x[i,...].sum() == 0:
                    x[i,...] = fallback

        if self.transform is not None:
            # albumentations expects channel-last (H, W, C) arrays; handing it
            # the (10, P, P) stack made it treat the slice axis as the height.
            x = self.transform(image=np.ascontiguousarray(x.transpose(1, 2, 0)))['image']
            x = np.ascontiguousarray(x.transpose(2, 0, 1))
        x=torch.as_tensor(x).float()
        x = F.interpolate(x.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False).squeeze(0)
        label = torch.as_tensor(labels[row['label']])

        return [x.to(device),m.to(device)],[label.to(device),m[level].to(device)]


In [None]:
def myLoss(preds,target):
    """Class-weighted cross-entropy over the three severity classes.

    Args:
        preds: Logits of shape (batch, 3).
        target: Pair ``(class_indices, mask)``; the mask is accepted for
            interface compatibility and not used.

    Returns:
        Scalar loss tensor with class weights 1, 2 and 4.
    """
    class_indices, _mask = target
    # Build the weights on the logits' own device and dtype rather than the
    # global `device`. The former `preds + 1e-12` is dropped: adding the same
    # constant to every logit does not change the softmax.
    class_weights = torch.as_tensor([1., 2., 4.], dtype=preds.dtype, device=preds.device)
    return nn.CrossEntropyLoss(weight=class_weights)(preds, class_indices)

In [None]:
pip install timm

In [None]:
import timm
class ViT(nn.Module):
    """Classifier wrapping timm's `eva02_base_patch14_224` with 10 input channels."""

    def __init__(self, num_classes):
        super().__init__()
        # Pretrained backbone with `num_classes` outputs and a 10-channel input.
        backbone = timm.create_model(
            'eva02_base_patch14_224',
            pretrained=True,
            num_classes=num_classes,
            in_chans=10,
        )
        self.vit = backbone

    def forward(self, x):
        """Run the backbone on the first element of ``x``; the second is ignored."""
        stack, _mask = x
        return self.vit(stack)

In [None]:
# Load the pickled U-Net of this fold and wrap both splits in ViT_T1_Dataset.
# NOTE(review): tdf2 / vdf2 are only defined in a later cell, so running the
# notebook top to bottom fails here; the `if 1:` training cell repeats this
# setup with a different file name. Confirm whether this cell is still needed.
# NOTE(review): pickle.load executes code from the file - only load trusted files.
seed_everything(SEED)
with open('/content/drive/MyDrive/RSNA_csv/'+"SEG_"+str(fold)+".pkl", 'rb') as f:
  UNet=pickle.load(f)
tds = ViT_T1_Dataset(tdf2,UNet)
vds = ViT_T1_Dataset(vdf2,UNet,VALID=True)

In [None]:
# Assign every study a fold id in 1..5 by cycling through the studies in order
# of first appearance, so all rows of one study share a fold.
unique_studies = grouped_df_s2['study_id'].unique()
study_mapping = {study: (i % 5) + 1 for i, study in enumerate(unique_studies)}

# Map the numbers back to the original DataFrame
grouped_df_s2['series_description2'] = grouped_df_s2['study_id'].map(study_mapping)


In [None]:
# Hold out the studies whose fold id equals `fold` for validation.
tdf2=grouped_df_s2[grouped_df_s2['series_description2'] != fold]
vdf2=grouped_df_s2[grouped_df_s2['series_description2'] == fold]

In [None]:
from fastai.callback.core import Callback

class SaveModelCallback(Callback):
    """Save the learner at the end of every epoch.

    NOTE(review): the name matches a callback that fastai itself ships, which
    this definition presumably shadows after the star import - confirm.

    Args:
        every_epoch: Save after each epoch when True; otherwise do nothing.
        path: Directory part of the checkpoint name.
        fname: File-name stem; ``_ep_<epoch>`` is appended.
        with_opt: Whether the optimizer state is saved with the model.
    """

    def __init__(self, every_epoch=False, path='models', fname='model',with_opt=False):
        self.every_epoch = every_epoch
        self.path = path
        self.fname = fname
        self.with_opt=with_opt

    def after_epoch(self):
        # Save the model after every epoch. `with_opt` is now forwarded;
        # before, it was stored but never reached `learn.save`.
        if self.every_epoch:
            self.learn.save(f'{self.path}/{self.fname}_ep_{self.epoch}', with_opt=self.with_opt)

save_model_cb = SaveModelCallback(every_epoch=True, path='/content/drive/MyDrive/RSNA_csv', fname=f"eva_sagt2_aug_p_90_ch_10_f_{fold}",with_opt=True)

In [None]:
# Train the severity classifier on crops located by the pickled U-Net.
if 1:
    seed_everything(SEED)
    # NOTE(review): the fold in the file name is hard-coded to "1" instead of
    # using `fold`, and the name differs from the one the segmentation cell
    # writes - confirm.
    # NOTE(review): pickle.load executes code from the file - only load trusted files.
    with open('/content/drive/MyDrive/RSNA_csv/'+"SEG_"+"SCS_"+"1"+".pkl", 'rb') as f:
      UNet=pickle.load(f)
    # Augmentation is applied to the training split only.
    tds = ViT_T1_Dataset(tdf2,UNet,transform=transforms_train)
    vds = ViT_T1_Dataset(vdf2,UNet,VALID=True)
    # NOTE(review): the batch size is fixed at 16 here although INF['BS'] is 64.
    tdl = torch.utils.data.DataLoader(tds, batch_size=16, shuffle=True, drop_last=True)
    vdl = torch.utils.data.DataLoader(vds, batch_size=16, shuffle=False)

    dls = DataLoaders(tdl,vdl)

    # NOTE(review): n_iter is computed but not used in this cell.
    n_iter = len(tds)//INF['BS']

    model = ViT(num_classes=3)
    model.to(device)
    learn = Learner(
        dls,
        model,
        lr=INF['LR'],
        loss_func=myLoss,
        cbs=[
            save_model_cb,
            GradientClip,
            ShowGraphCallback(),
            # alpha_cb
        ]
    )
    learn.fit_one_cycle(INF['EPOCHS'])
    # with open('/content/drive/MyDrive/RSNA_csv/'+"VIT_"+str(fold)+".pkl", 'wb') as f:
      # pickle.dump(model, f)