# 导入相应库

In [1]:
import os
import sys

import cv2
import numpy as np
import pandas as pd 
from PIL import Image
from matplotlib import pyplot as plt
import seaborn as sns
import time
import random
import shutil 

import scipy as sp
from sklearn.model_selection import StratifiedKFold,GroupKFold,KFold # 交叉验证
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam,SGD,AdamW

import torchvision.models as models 
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, CosineAnnealingLR, ReduceLROnPlateau
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm

#import loss_func
from torch.cuda.amp import autocast, GradScaler
import warnings
warnings.filterwarnings('ignore')

# 基本配置

In [2]:
class CFG:

    apex = True
    debug = False
    print_freq = 200
    num_workers = 2
    model_name = 'tf_efficientnet_b2'
    size_w = 224
    size_h = 224  #819
    scheduler = 'CosineAnnealingLR'  # ['ReduceLROnPlateau', 'CosineAnnealingLR', 'CosineAnnealingWarmRestarts']
    epochs = 5
    # factor=0.2 # ReduceLROnPlateau
    # patience=4 # ReduceLROnPlateau
    # eps=1e-6 # ReduceLROnPlateau
    T_max = 5  # CosineAnnealingLR8
    T_0=5   # CosineAnnealingWarmRestarts
    lr = 3e-4
    min_lr = 1e-6
    batch_size = 64
    weight_decay = 1e-6
    gradient_accumulation_steps = 1
    max_grad_norm = 1000
    seed = 42
    target_col = 'target'
    n_fold = 5
    trn_fold = [0,1,2,3,4]
    train = True

In [3]:
train = pd.read_csv("./Data/BoolArt/train.csv")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
OUTPUT_DIR = './'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
    
if CFG.debug:
    CFG.epochs = 3
    train = train.sample(n=1000, random_state=CFG.seed).reset_index(drop=True)

def seed_torch(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    
seed_torch(seed=CFG.seed)


### 四则交叉验证

In [4]:
Fold = StratifiedKFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)
#labellll = np.array(train[CFG.target_col]).astype(np.int)
for n, (train_index, val_index) in enumerate(Fold.split(
                                                train, np.array(train[CFG.target_col]))):
    train.loc[val_index, 'fold'] = int(n)
train['fold'] = train['fold'].astype(int)

#### 查看数据集

In [9]:
train

Unnamed: 0,id,target,fold
0,15970,0,0
1,59263,4,3
2,21379,3,4
3,1855,0,4
4,30805,0,2
...,...,...,...
35546,17036,1,4
35547,6461,11,0
35548,18842,0,2
35549,46694,8,3


In [14]:
train['id']

0        15970
1        59263
2        21379
3         1855
4        30805
         ...  
35546    17036
35547     6461
35548    18842
35549    46694
35550    51623
Name: id, Length: 35551, dtype: int64

### 加载训练和验证数据

In [10]:
# ====================================================
# Dataset 
# ====================================================
class TrainDataset(Dataset):
    def __init__(self,df,transform=None):
        self.df = df
        self.file_names = df['id'].values
        self.labels = df[CFG.target_col].values
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    # 读取图片
    def __getitem__(self,idx): # 这里的idx如何读取呢？
        self.file_path = f'./Data/BoolArt/train_image/{self.file_names[idx]}.jpg' # 读取图片地址
        image = np.narray(Image.open(self.file_path).convert("RGB"))
        if self.transform:
            image = self.transform(image=image)['image']
        else:
            image = cv2.resize(image,(CFG.size_h,CFG.size_w)) # 和原码不一样
            image = image[np.newaxis,:,:]
            image = torch.form_numpy(image).float()
        
            

In [15]:
traindataset = TrainDataset(train)

In [17]:
file_names = train['id'].values

In [18]:
file_path = f'../input/boolart-image-classification/train_image/{file_names[idx]}.jpg'

NameError: name 'idx' is not defined