In [1]:
import numpy as np                # 数组的数值计算库
import pandas as pd               # 数据处理和分析库
import matplotlib.pyplot as plt   # 数据可视化库
import torch                       # PyTorch深度学习库
import os                          # 与操作系统交互的库
import pytorch_lightning as pl     # 基于PyTorch的轻量级深度学习框架
from torch.utils.data import Dataset, DataLoader   # PyTorch数据加载和处理库
from sklearn import model_selection   # 机器学习模型选择和评估库
import torchvision.transforms as transforms   # PyTorch中的图像转换库
import torchvision.io    # PyTorch中加载和保存图像和视频文件的库
import librosa          # 音频处理和分析库
from PIL import Image   # Python中处理图像的库
import albumentations as alb   # 图像增强技术库
import torch.multiprocessing as mp   # PyTorch中的多进程数据加载库
import warnings          # Python中的警告处理库

warnings.filterwarnings('ignore')

In [2]:
import ast                        # safe parsing of stringified list columns (ast.literal_eval)
import numpy as np                # 用于数组的数值计算
import librosa as lb              # 音频处理和分析库
import librosa.display as lbd     # 用于在matplotlib中显示音频处理的库
import soundfile as sf            # 读写音频文件的库
from soundfile import SoundFile   # 用于读取音频文件的类
import pandas as pd               # 数据处理和分析库
from IPython.display import Audio  # 在Jupyter Notebook中播放音频的库
from pathlib import Path          # 处理路径的库

from matplotlib import pyplot as plt   # 数据可视化库
from tqdm.notebook import tqdm          # 在Jupyter Notebook中显示进度条的库
import joblib, json, re                # 数据处理和序列化的库

from sklearn.model_selection import StratifiedKFold  # 用于交叉验证的类
tqdm.pandas()

from pytorch_lightning.callbacks import ModelCheckpoint, BackboneFinetuning, EarlyStopping #lightning中一些常用的回调函数

In [4]:
def compute_melspec(y, sr, n_mels, fmin, fmax):
    """Return the mel spectrogram of waveform `y`, converted to decibels, as float32.

    Args:
        y: waveform samples.
        sr: sampling rate of `y`.
        n_mels: number of mel bands.
        fmin, fmax: frequency range covered by the mel filter bank.
    """
    power_mel = lb.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels, fmin=fmin, fmax=fmax)
    mel_db = lb.power_to_db(power_mel)
    return mel_db.astype(np.float32)

In [5]:
class Config:
    """Notebook-wide settings: audio/mel parameters, training hyper-parameters and paths."""

    # ---- audio / mel-spectrogram ----
    SR = 32000                      # audio sampling rate (Hz)
    DURATION = 5                    # window length in seconds
    sampling_rate = SR              # same value as SR (name used by the feature-extraction code)
    duration = DURATION             # same value as DURATION (name used by the feature-extraction code)
    fmin = 0                        # lowest mel-filter frequency
    fmax = None                     # highest mel-filter frequency (None -> sr // 2 in the callers)
    MAX_READ_SAMPLES = 5            # max number of cached spectrogram windows loaded per recording

    # ---- training ----
    use_aug = False                 # whether to apply data augmentation
    num_classes = 360               # number of output classes (placeholder; reassigned from the data later)
    batch_size = 64                 # batch size
    epochs = 5                      # number of training epochs
    LR = 5e-4                       # learning rate
    weight_decay = 1e-3             # weight-decay coefficient
    PRECISION = 16                  # Lightning precision (16 = mixed precision)
    PATIENCE = 8                    # EarlyStopping patience
    seed = 2023                     # random seed
    use_mixup = True                # whether to train with MixUp
    mixup_alpha = 0.2               # MixUp Beta-distribution parameter

    # ---- model ----
    model = "tf_efficientnet_b0_ns"   # timm architecture name
    pretrained = True                 # load pretrained weights
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')   # GPU if available, else CPU

    # ---- paths ----
    data_root = "/kaggle/input/birdclef-2023/"                                    # competition data root
    train_images = "/kaggle/input/split-creating-melspecs-stage-1/specs/train/"   # cached train spectrograms
    valid_images = "/kaggle/input/split-creating-melspecs-stage-1/specs/valid/"   # cached valid spectrograms
    train_path = "/kaggle/input/bc2023-train-val-df/train.csv"                   # train metadata CSV
    valid_path = "/kaggle/input/bc2023-train-val-df/valid.csv"                   # valid metadata CSV
    audios_path = Path("/kaggle/input/birdclef-2023/train_audio")               # raw audio directory
    out_dir_train = Path("specs/train")   # output dir for generated train spectrograms
    out_dir_valid = Path("specs/valid")   # output dir for generated valid spectrograms


In [6]:
# Convert a single-channel (mel) spectrogram into a single-channel uint8 image.
def mono_to_color(X, eps=1e-6, mean=None, std=None):
    """Standardize `X`, then min-max scale it to the uint8 range [0, 255].

    Args:
        X: NumPy array, e.g. a mel spectrogram in dB.
        eps: small constant guarding both divisions.
        mean: mean used for standardization; computed from `X` when None.
        std: standard deviation used for standardization; computed from `X` when None.

    Returns:
        uint8 array with the same shape as `X`. It is all zeros when the
        standardized value range is not larger than `eps`.
    """
    # `is None` rather than `or`, so an explicit mean=0.0 / std=0.0 is honoured.
    mean = X.mean() if mean is None else mean
    std = X.std() if std is None else std
    X = (X - mean) / (std + eps)

    _min, _max = X.min(), X.max()
    if (_max - _min) <= eps:
        return np.zeros_like(X, dtype=np.uint8)

    # X already lies within [_min, _max], so no clipping is needed before scaling.
    return (255 * (X - _min) / (_max - _min)).astype(np.uint8)

# Crop or zero-pad a 1-D signal to exactly `length` samples.
def crop_or_pad(y, length, is_train=True, start=None):
    """Return `y` cropped or zero-padded to exactly `length` samples.

    Args:
        y: 1-D NumPy array.
        length: target number of samples.
        is_train: when cropping without an explicit `start`, choose a random
            window (True) or the first window (False).
        start: explicit crop offset; only used when `y` is longer than `length`.

    Returns:
        An array of `length` samples with the same dtype as `y`. When `y` already
        has `length` samples it is returned unchanged.
    """
    n_samples = len(y)
    if n_samples < length:
        # Zero-pad at the end. np.pad keeps y's dtype, whereas concatenating
        # np.zeros(...) would upcast float32 audio to float64.
        return np.pad(y, (0, length - n_samples))
    if n_samples > length:
        if start is None:
            # randint's upper bound is exclusive, so +1 makes the last window reachable.
            start = np.random.randint(n_samples - length + 1) if is_train else 0
        return y[start:start + length]
    return y

In [7]:
class AudioToImage:
    """Turn an audio file into a stack of uint8 mel-spectrogram images, one per window.

    The file is read with soundfile and optionally resampled to `sr`. It is then
    cut into windows of `duration` seconds, starting every `step` samples. Only
    windows that fit completely are kept, except that a clip shorter than one
    window yields a single zero-padded window. Each window becomes a dB mel
    spectrogram (`compute_melspec`) scaled to uint8 (`mono_to_color`).
    """

    def __init__(self, sr=Config.sampling_rate, n_mels=128, fmin=Config.fmin, fmax=Config.fmax, duration=Config.duration, step=None, res_type="kaiser_fast", resample=True, train=True):
        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        # None (or 0) means "up to sr // 2".
        self.fmax = fmax or self.sr // 2

        self.duration = duration
        self.audio_length = self.duration * self.sr   # window length in samples
        self.step = step or self.audio_length          # hop; default gives non-overlapping windows

        self.res_type = res_type
        self.resample = resample

        # Chooses the output directory used by __call__(save=True).
        self.train = train

    def audio_to_image(self, audio):
        """Mel spectrogram (dB) of one window, scaled to a uint8 image."""
        melspec = compute_melspec(audio, self.sr, self.n_mels, self.fmin, self.fmax)
        return mono_to_color(melspec)

    def _load_windows(self, path):
        """Read `path`, resample to `self.sr` if needed, and split into windows."""
        audio, orig_sr = sf.read(path, dtype="float32")
        if self.resample and orig_sr != self.sr:
            # Keyword form: librosa >= 0.10 makes orig_sr / target_sr keyword-only.
            audio = lb.resample(audio, orig_sr=orig_sr, target_sr=self.sr, res_type=self.res_type)
        # The upper bound keeps only windows that fit entirely. max(1, ...) guarantees
        # one window for clips shorter than audio_length.
        starts = range(0, max(1, len(audio) - self.audio_length + 1), self.step)
        windows = [audio[i:i + self.audio_length] for i in starts]
        # Only that single short-clip window can be shorter than audio_length.
        windows[-1] = crop_or_pad(windows[-1], length=self.audio_length)
        return windows

    def _windows_to_images(self, windows):
        """Stack the per-window images into one array, windows along axis 0."""
        return np.stack([self.audio_to_image(window) for window in windows])

    def __call__(self, row, save=True):
        """Convert the file at `row.path`.

        With `save=True`, the image stack is written to
        ``Config.out_dir_train`` / ``Config.out_dir_valid`` as ``<row.filename>.npy``
        and None is returned. Otherwise ``(row.filename, images)`` is returned.
        """
        images = self._windows_to_images(self._load_windows(row.path))
        if not save:
            return row.filename, images
        out_dir = Config.out_dir_train if self.train else Config.out_dir_valid
        path = out_dir / f"{row.filename}.npy"
        path.parent.mkdir(exist_ok=True, parents=True)
        np.save(str(path), images)

    def vismel(self, row, idx, save=True):
        """Return the image stack for the file `row.path[idx]` without writing anything.

        `save` is accepted for backward compatibility and ignored.
        """
        return self._windows_to_images(self._load_windows(row.path[idx]))

    def get_audio(self, row, idx):
        """Return the list of (resampled) waveform windows for the file `row.path[idx]`."""
        return self._load_windows(row.path[idx])

In [8]:
!pip install -q torchtoolbox timm


[0m

In [9]:
pl.seed_everything(seed=Config.seed, workers=True)  # seed Lightning-managed RNGs (workers=True also seeds DataLoader workers); NOTE(review): TensorFlow's RNG used by the TF augmentations is not seeded here

2023

In [10]:
def config_to_dict(cfg):
    """Return ``{name: value}`` for every attribute of `cfg` whose name does not start with '__'."""
    return {attr: getattr(cfg, attr) for attr in dir(cfg) if not attr.startswith('__')}

In [11]:
# Load the pre-split train / validation metadata tables.
df_train, df_valid = (pd.read_csv(csv_path) for csv_path in (Config.train_path, Config.valid_path))
df_train.head()

Unnamed: 0,primary_label,secondary_labels,type,latitude,longitude,scientific_name,common_name,author,license,rating,url,filename,len_sec_labels,path,frames,sr,duration
0,yebapa1,[],['song'],-3.3923,36.7049,Apalis flavida,Yellow-breasted Apalis,isaac kilusu,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/422175,yebapa1/XC422175.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,405504,32000,12.672
1,yebapa1,[],['song'],-0.6143,34.0906,Apalis flavida,Yellow-breasted Apalis,James Bradley,Creative Commons Attribution-NonCommercial-Sha...,3.0,https://www.xeno-canto.org/289562,yebapa1/XC289562.ogg,0,/kaggle/input/birdclef-2023/train_audio/yebapa...,796630,32000,24.894687
2,combuz1,[],['call'],51.8585,-8.2699,Buteo buteo,Common Buzzard,Irish Wildlife Sounds,Creative Commons Attribution-NonCommercial-Sha...,4.0,https://www.xeno-canto.org/626969,combuz1/XC626969.ogg,0,/kaggle/input/birdclef-2023/train_audio/combuz...,254112,32000,7.941
3,chibat1,['laudov1'],"['adult', 'sex uncertain', 'song']",-33.1465,26.4001,Batis molitor,Chinspot Batis,Lynette Rudman,Creative Commons Attribution-NonCommercial-Sha...,3.5,https://www.xeno-canto.org/664196,chibat1/XC664196.ogg,1,/kaggle/input/birdclef-2023/train_audio/chibat...,1040704,32000,32.522
4,carcha1,[],['song'],-34.011,18.8078,Cossypha caffra,Cape Robin-Chat,Shannon Ronaldson,Creative Commons Attribution-NonCommercial-Sha...,1.0,https://www.xeno-canto.org/322333,carcha1/XC322333.ogg,0,/kaggle/input/birdclef-2023/train_audio/carcha...,40124,32000,1.253875


In [12]:
# Vocabulary of call-type tags (e.g. 'song', 'call'), in order of first appearance
# in the training set. A tag's position in this list is its label index.
# Each `type` cell is a stringified Python list. ast.literal_eval parses it without
# executing arbitrary code (unlike eval). dict.fromkeys de-duplicates while keeping order.
typelist = list(dict.fromkeys(
    tag
    for raw_types in df_train['type']
    for tag in ast.literal_eval(raw_types)
))

In [13]:
# One output unit per call-type tag found in the training data.
Config.num_classes = len(typelist)

In [14]:
# All-zero template for the multi-hot call-type target. Its length must equal the model's
# output size (Config.num_classes), so derive it instead of hard-coding 360.
multi_label = torch.zeros(Config.num_classes, dtype=torch.float)

In [15]:
# Append one-hot species columns. The validation dummies are aligned to the training
# species so both frames expose the same columns in the same order. Training species
# missing from validation become all-zero columns; validation-only species are dropped.
# A fixed numeric dtype avoids bool/int mixing when the columns are used as targets.
species_vocab = sorted(df_train['primary_label'].unique())
df_train = pd.concat([df_train, pd.get_dummies(df_train['primary_label'], dtype=np.uint8).reindex(columns=species_vocab, fill_value=0)], axis=1)
df_valid = pd.concat([df_valid, pd.get_dummies(df_valid['primary_label'], dtype=np.uint8).reindex(columns=species_vocab, fill_value=0)], axis=1)

## 一些数据增强策略

In [16]:
import albumentations as A
def get_train_transform():
    """Build the spectrogram augmentation pipeline (not called by the visible cells).

    The pipeline applies a horizontal flip with p=0.5. Then, with p=0.5, it applies
    exactly one of Cutout / CoarseDropout.
    """
    # NOTE(review): A.Cutout is deprecated in favour of CoarseDropout and may be missing
    # from newer albumentations releases; confirm against the installed version.
    erase_one = A.OneOf(
        [
            A.Cutout(max_h_size=5, max_w_size=16),
            A.CoarseDropout(max_holes=4),
        ],
        p=0.5,
    )
    flip = A.HorizontalFlip(p=0.5)
    return A.Compose([flip, erase_one])

In [17]:
    # NOTE(review): this cell is a stray copy of AudioToImage.__init__ with no enclosing
    # class. Because of its leading indentation it only executes if the notebook front end
    # strips that indent. It then merely defines an unused module-level function named
    # `__init__`. No visible cell calls it; consider deleting the cell.
    def __init__(self, sr=Config.sampling_rate, n_mels=128, fmin=Config.fmin, fmax=Config.fmax, duration=Config.duration, step=None, res_type="kaiser_fast", resample=True, train = True):
        """Store audio / mel-spectrogram settings on `self` (same body as AudioToImage.__init__)."""

        self.sr = sr
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax or self.sr//2  # None (or 0) falls back to sr // 2

        self.duration = duration
        self.audio_length = self.duration*self.sr  # window length in samples
        self.step = step or self.audio_length  # hop between windows; defaults to one window
        
        self.res_type = res_type
        self.resample = resample

        self.train = train

### **Utility**

In [18]:
# Uniform random int32 sampler (TensorFlow).
def random_int(shape=[], minval=0, maxval=1):
    """Sample int32 values uniformly from ``[minval, maxval)`` with the given `shape`."""
    sample = tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.int32)
    return sample


# Uniform random float32 sampler (TensorFlow).
def random_float(shape=[], minval=0.0, maxval=1.0):
    """Sample float32 values uniformly from ``[minval, maxval)`` with the given `shape`."""
    return tf.random.uniform(shape=shape, minval=minval, maxval=maxval, dtype=tf.float32)

# **Some augmentation tricks for audio**

In [19]:
# Import required packages
import tensorflow as tf

# Randomly crop or pad audio data to a target length.
@tf.function
def CropOrPad(audio, target_len, pad_mode='constant'):
    """Randomly crop or pad `audio` along axis 0 to exactly `target_len` samples.

    Shorter input is padded (with `pad_mode`), with the padding split at random
    between the start and the end. Longer input is cropped at a random offset.
    Every split / offset in ``[0, |len(audio) - target_len|]`` can occur.

    Returns:
        Tensor reshaped to ``[target_len]``.
    """
    audio_len = tf.shape(audio)[0]
    if audio_len < target_len:
        diff_len = target_len - audio_len
        # Integer uniform draws exclude `maxval`; +1 lets pad1 == diff_len
        # (all padding before the signal).
        pad1 = random_int([], minval=0, maxval=diff_len + 1)
        pad2 = diff_len - pad1
        audio = tf.pad(audio, paddings=[[pad1, pad2]], mode=pad_mode)
    elif audio_len > target_len:
        diff_len = audio_len - target_len
        # Same +1 so that the last full window (offset diff_len) is reachable.
        idx = tf.random.uniform([], 0, diff_len + 1, dtype=tf.int32)
        audio = audio[idx:idx + target_len]
    # Give the result a static shape of [target_len].
    return tf.reshape(audio, [target_len])


# Random cyclic time shift: a sound at time t may move to time t + shift.
@tf.function
def TimeShift(audio, prob=0.5):
    """With probability `prob`, roll `audio` along axis 0 by a random offset.

    The offset magnitude is uniform in ``[0, len(audio))``, and its sign is
    negative with probability 0.5. ``tf.roll`` wraps samples around the ends.
    """
    apply_shift = random_float() < prob
    if apply_shift:
        magnitude = random_int(shape=[], minval=0, maxval=tf.shape(audio)[0])
        go_left = random_float() < 0.5
        offset = tf.where(go_left, -magnitude, magnitude)
        audio = tf.roll(audio, offset, axis=0)
    return audio

# Apply random additive Gaussian noise to audio data.
@tf.function
def GaussianNoise(audio, std=(0.0025, 0.025), prob=0.5):
    """With probability `prob`, add zero-mean Gaussian noise to `audio`.

    Args:
        audio: float waveform (NumPy array or tensor).
        std: ``(low, high)`` range; the noise standard deviation is drawn
            uniformly from ``[low, high)`` on every call.
        prob: probability of adding noise.

    Returns:
        Tensor of the same shape and dtype as `audio`.
    """
    # The std is drawn before the probability check, as before, so the order of random draws is unchanged.
    noise_std = random_float([], std[0], std[1])
    if random_float() < prob:
        # Equivalent to keras GaussianNoise(training=True): add N(0, noise_std^2) noise.
        # Sampling directly avoids building a Keras layer inside the traced function
        # and casts the std to the audio dtype (float64 audio is supported too).
        noise = tf.random.normal(
            tf.shape(audio), mean=0.0, stddev=tf.cast(noise_std, audio.dtype), dtype=audio.dtype
        )
        audio = audio + noise
    return audio

# Applies augmentation to an audio signal.
def AudioAug(audio, timeshift_prob=None, gn_prob=None):
    """Randomly time-shift `audio`, then randomly add Gaussian noise.

    Args:
        audio: waveform accepted by `TimeShift` / `GaussianNoise`.
        timeshift_prob: probability of the time shift. Defaults to
            ``Config.timeshift_prob`` if that attribute exists, else 0.5.
        gn_prob: probability of adding noise. Defaults to ``Config.gn_prob``
            if that attribute exists, else 0.5.

    Returns:
        The augmented waveform as a TF tensor.
    """
    # Config does not define these probabilities, so fall back to 0.5, the helpers'
    # own default.
    if timeshift_prob is None:
        timeshift_prob = getattr(Config, "timeshift_prob", 0.5)
    if gn_prob is None:
        gn_prob = getattr(Config, "gn_prob", 0.5)
    audio = TimeShift(audio, prob=timeshift_prob)
    audio = GaussianNoise(audio, prob=gn_prob)
    return audio

# Standardize the audio (z-score, then optional [0, 1] rescale).
@tf.function
def Normalize(data, min_max=True):
    """Z-score `data`, then (if `min_max`) min-max rescale it to [0, 1].

    Both divisions use ``divide_no_nan``, so constant input yields zeros instead of NaN.
    """
    centered = data - tf.math.reduce_mean(data)
    data = tf.math.divide_no_nan(centered, tf.math.reduce_std(data))
    if not min_max:
        return data
    lowest = tf.math.reduce_min(data)
    spread = tf.math.reduce_max(data) - lowest
    return tf.math.divide_no_nan(data - lowest, spread)

In [20]:
class BirdDataset_formerge(torch.utils.data.Dataset):
    """Dataset that fuses one raw waveform window with its cached mel-spectrogram image.

    Each item is ``(merged, label)``:

    * ``merged``: sum of two 3-channel tensors.
      - The cached spectrogram image for the chosen window, repeated to 3 channels
        and divided by 255.
      - Three consecutive 128*313-sample slices of the same waveform window, each
        reshaped to (128, 313).
      NOTE(review): the sum only broadcasts if the cached images are 128 x 313;
      confirm against the cache builder.
    * ``label``: multi-hot float vector over the call-type vocabulary
      (``type_vocab``; default is the global ``typelist``), built from the row's
      ``type`` column. Tags outside the vocabulary are ignored.
    """

    def __init__(self, df, multi_label=multi_label, sr=Config.SR, duration=Config.DURATION, augmentations=None, train=True, n_mels=128, fmin=Config.fmin, fmax=Config.fmax, step=None, type_vocab=None):
        """
        Args:
            df: metadata DataFrame with at least ``path``, ``filename`` and ``type`` columns.
            multi_label: all-zero label template whose length is the number of classes.
            sr: expected sampling rate of the audio files (files are not resampled).
            duration: window length in seconds.
            augmentations: optional callable applied to the spectrogram tensor after
                ``unsqueeze(0)``.
            train: training mode (random window + time shift) vs. validation (first window).
            n_mels, fmin, fmax: mel parameters, stored only (fmax defaults to sr // 2).
            step: hop between windows in samples (default: one window).
            type_vocab: ordered call-type tags; the list position is the label index.
        """
        self.df = df
        self.sr = sr
        self.train = train
        self.duration = duration
        self.augmentations = augmentations
        # Zero template for the label vector; cloned per item and never mutated.
        self.labels = multi_label
        self.n_mels = n_mels
        self.fmin = fmin
        self.fmax = fmax or self.sr // 2
        self.audio_length = self.duration * self.sr   # window length in samples
        self.step = step or self.audio_length         # hop between windows
        # Converts a TF EagerTensor or a NumPy array into a torch tensor.
        self.tf2torch = lambda x: torch.tensor(np.asarray(x))
        # Tag -> label index. The first occurrence wins, matching list.index semantics.
        vocab = typelist if type_vocab is None else type_vocab
        self.type_to_index = {}
        for position, tag in enumerate(vocab):
            self.type_to_index.setdefault(tag, position)
        self.img_dir = Config.train_images if train else Config.valid_images

    def __len__(self):
        return len(self.df)

    @staticmethod
    def normalize(image):
        """Scale uint8-range pixel values to [0, 1]."""
        return image / 255.0

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Cached spectrogram windows for this recording (at most MAX_READ_SAMPLES).
        images = np.load(self.img_dir + f"{row.filename}.npy")[:Config.MAX_READ_SAMPLES]

        # Window the waveform with the same length/step as the spectrogram cache.
        # NOTE(review): assumes the cache used this windowing (as AudioToImage does)
        # and that files are already at self.sr; the file's sampling rate is not checked.
        audio, _ = sf.read(row.path, dtype="float32")
        starts = range(0, max(1, len(audio) - self.audio_length + 1), self.step)

        # A single window index selects BOTH the waveform slice and the spectrogram,
        # so the two inputs describe the same segment. The index is random for
        # training and the first window for validation.
        n_windows = min(len(images), len(starts))
        k = np.random.randint(n_windows) if self.train else 0
        window = audio[starts[k]:starts[k] + self.audio_length]
        # Only a clip shorter than one window can be short; pad it to full length.
        window = crop_or_pad(window, length=self.audio_length)

        if self.train:
            # Time-shift augmentation on training samples only (returns a TF tensor).
            window = TimeShift(window)
        waveform = self.tf2torch(window).float()

        # Fold three consecutive 128*313-sample slices into three (128, 313) planes.
        plane = 128 * 313
        merged_audio = torch.stack(
            [waveform[c * plane:(c + 1) * plane].reshape(128, 313) for c in range(3)]
        )

        image = torch.tensor(images[k]).float()
        if self.augmentations:
            image = self.augmentations(image.unsqueeze(0)).squeeze()
        image = self.normalize(torch.stack([image, image, image]))

        # Multi-hot target from the row's call-type tags (a stringified list). Read by
        # column name: positional row[2] would be `secondary_labels` in this CSV.
        label = self.labels.clone()
        for tag in ast.literal_eval(row['type']):
            label_idx = self.type_to_index.get(tag)
            if label_idx is not None:
                label[label_idx] = 1
        return merged_audio + image, label

In [24]:
def get_fold_dls(df_train, df_valid):
    """Build the train / validation datasets and their DataLoaders.

    Returns:
        ``(dl_train, dl_val, ds_train, ds_val)``. Only the training loader shuffles;
        both use `Config.batch_size` and 2 worker processes.
    """
    shared = dict(sr=Config.SR, duration=Config.DURATION, augmentations=None)
    ds_train = BirdDataset_formerge(df_train, train=True, **shared)
    ds_val = BirdDataset_formerge(df_valid, train=False, **shared)

    loader_kwargs = dict(batch_size=Config.batch_size, num_workers=2)
    dl_train = DataLoader(ds_train, shuffle=True, **loader_kwargs)
    dl_val = DataLoader(ds_val, **loader_kwargs)
    return dl_train, dl_val, ds_train, ds_val

In [25]:
def show_batch(img_ds, num_items, num_rows, num_cols, predict_arr=None):
    """Plot `num_items` randomly chosen samples of `img_ds` on a subplot grid.

    Args:
        img_ds: indexable dataset returning ``(image, label)`` with channel-first images.
        num_items: number of samples to draw (with replacement).
        num_rows, num_cols: grid shape; must provide at least `num_items` axes.
        predict_arr: unused; kept for backward compatibility.

    Returns:
        The matplotlib Figure.
    """
    fig = plt.figure(figsize=(12, 6))
    # randint's `high` is exclusive, so len(img_ds) makes every index eligible.
    sample_indices = np.random.randint(0, len(img_ds), num_items)
    for plot_pos, sample_idx in enumerate(sample_indices):
        img, _label = img_ds[sample_idx]
        ax = fig.add_subplot(num_rows, num_cols, plot_pos + 1, xticks=[], yticks=[])
        if isinstance(img, torch.Tensor):
            img = img.detach().cpu().numpy()
        if isinstance(img, np.ndarray):
            # Channel-first -> channel-last for imshow. Rescale to [0, 1] because float
            # RGB input outside that range is clipped by imshow.
            img = img.transpose(1, 2, 0).astype(np.float32)
            lo, hi = img.min(), img.max()
            if hi > lo:
                img = (img - lo) / (hi - lo)
            ax.imshow(img)
        ax.set_title("Spec")
    return fig

In [26]:
dl_train, dl_val, ds_train, ds_val = get_fold_dls(df_train=df_train, df_valid=df_valid)  # loaders/datasets for inspection (run_training builds its own)

In [27]:
from torch.optim.lr_scheduler import CosineAnnealingLR, CosineAnnealingWarmRestarts, ReduceLROnPlateau, OneCycleLR

def get_optimizer(lr, params):
    """Create the Adam optimizer and cosine warm-restart scheduler for Lightning.

    Args:
        lr: initial learning rate.
        params: iterable of parameters; those with ``requires_grad=False`` are skipped.

    Returns:
        Dict in the format accepted from ``LightningModule.configure_optimizers``.
    """
    trainable = (p for p in params if p.requires_grad)
    optimizer = torch.optim.Adam(trainable, lr=lr, weight_decay=Config.weight_decay)

    # One cosine cycle per `Config.epochs` scheduler steps; the scheduler is stepped per epoch.
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=Config.epochs, T_mult=1, eta_min=1e-6, last_epoch=-1
    )
    scheduler_config = {
        "scheduler": scheduler,
        "interval": "epoch",
        "monitor": "val_loss",
        "frequency": 1,
    }
    return {"optimizer": optimizer, "lr_scheduler": scheduler_config}

In [28]:
from torchtoolbox.tools import mixup_data, mixup_criterion
import torch.nn as nn
from torch.nn.functional import cross_entropy
import torchmetrics
import timm

In [29]:
import sklearn.metrics

def padded_cmap(solution, submission, padding_factor=5):
    """Padded class-wise mean average precision (macro-averaged AP).

    `padding_factor` all-ones rows are appended to both the targets and the
    scores, so every class has at least one positive before AP is computed.

    Args:
        solution: DataFrame of 0/1 (or bool) targets, one column per class.
        submission: DataFrame of scores with the same columns/order as `solution`.
        padding_factor: number of all-ones rows to append (0 disables padding).

    Returns:
        float: macro-averaged average precision over the columns.
    """
    padding = np.ones((padding_factor, solution.shape[1]))
    # Work on float arrays. Concatenating bool dummy columns with integer padding
    # rows in pandas yields object columns, which sklearn rejects.
    padded_solution = np.concatenate([solution.to_numpy(dtype=float), padding])
    padded_submission = np.concatenate([submission.to_numpy(dtype=float), padding])
    return sklearn.metrics.average_precision_score(
        padded_solution,
        padded_submission,
        average='macro',
    )

def map_score(solution, submission):
    """Micro-averaged average precision of `submission` scores against `solution` targets.

    Both arguments are DataFrames with matching columns.
    """
    y_true, y_score = solution.values, submission.values
    return sklearn.metrics.average_precision_score(y_true, y_score, average='micro')

In [30]:
class BirdClefModel(pl.LightningModule):
    """Lightning module: a timm image backbone with a `num_classes`-way multi-label head.

    Trained with BCE-with-logits, optionally with mixup. Validation logs the
    epoch-level `val_loss` and the padded class-wise mAP (`val_cmap`).
    """

    def __init__(self, model_name=Config.model, num_classes=Config.num_classes, pretrained=Config.pretrained):
        super().__init__()
        self.num_classes = num_classes
        self.backbone = timm.create_model(model_name, pretrained=pretrained)

        # Replace the pretrained classifier with a `num_classes` head. The attribute
        # that holds it depends on the architecture family.
        if 'res' in model_name:
            self.in_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Linear(self.in_features, num_classes)
        elif 'dense' in model_name:
            self.in_features = self.backbone.classifier.in_features
            self.backbone.classifier = nn.Linear(self.in_features, num_classes)
        elif 'efficientnet' in model_name:
            self.in_features = self.backbone.classifier.in_features
            # The Sequential wrapper is kept so saved state-dict keys stay "classifier.0.*".
            self.backbone.classifier = nn.Sequential(
                nn.Linear(self.in_features, num_classes)
            )
        else:
            # Any other timm model: use timm's generic head reset so the output size
            # always matches num_classes.
            self.in_features = self.backbone.num_features
            self.backbone.reset_classifier(num_classes)
        self.loss_function = nn.BCEWithLogitsLoss()

    def forward(self, images):
        """Return raw logits of shape (batch, num_classes)."""
        return self.backbone(images)

    def configure_optimizers(self):
        return get_optimizer(lr=Config.LR, params=self.parameters())

    def train_with_mixup(self, X, y):
        """Mixup step: blend the inputs and weight the two target losses by `lam`."""
        X, y_a, y_b, lam = mixup_data(X, y, alpha=Config.mixup_alpha)
        y_pred = self(X)
        # Use the same BCE-with-logits loss as the non-mixup and validation paths.
        # cross_entropy would instead treat the multi-hot targets as a softmax distribution.
        return mixup_criterion(self.loss_function, y_pred, y_a, y_b, lam)

    def training_step(self, batch, batch_idx):
        image, target = batch
        if Config.use_mixup:
            loss = self.train_with_mixup(image, target)
        else:
            y_pred = self(image)
            loss = self.loss_function(y_pred, target)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        image, target = batch
        y_pred = self(image)
        val_loss = self.loss_function(y_pred, target)
        # Epoch-level only, so the metric is logged under the plain name "val_loss"
        # that EarlyStopping / ModelCheckpoint monitor.
        self.log("val_loss", val_loss, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        return {"val_loss": val_loss, "logits": y_pred, "targets": target}

    # NOTE(review): `_train_dataloader` / `_validation_dataloader` are never assigned in
    # this notebook; the loaders are passed to `trainer.fit` instead. `validation_dataloader`
    # is not a Lightning hook name (Lightning's hook is `val_dataloader`).
    def train_dataloader(self):
        return self._train_dataloader

    def validation_dataloader(self):
        return self._validation_dataloader

    def validation_epoch_end(self, outputs):
        """Aggregate a validation epoch: mean batch loss and padded cmAP over all samples."""
        if not outputs:
            return {'val_loss': None, 'val_cmap': 0}
        avg_loss = torch.stack([out["val_loss"] for out in outputs]).mean()
        probs = torch.cat([out["logits"] for out in outputs]).detach().float().sigmoid().cpu().numpy()
        targets = torch.cat([out["targets"] for out in outputs]).detach().float().cpu().numpy()
        val_cmap = padded_cmap(pd.DataFrame(targets), pd.DataFrame(probs))
        self.log("val_cmap", val_cmap, prog_bar=True)
        return {'val_loss': avg_loss, 'val_cmap': val_cmap}
    
    
    
    

In [31]:
from pytorch_lightning.loggers import WandbLogger
import gc

def run_training():
    """Fit `BirdClefModel` on the global `df_train` / `df_valid` frames.

    Weights-only checkpoints go to /kaggle/working/exp1/: the best by `val_loss`
    plus the last one. EarlyStopping monitors `val_loss` with `Config.PATIENCE`.
    Validation runs twice per epoch (val_check_interval=0.5).

    Returns:
        str: path of the best checkpoint ('' if none was written).
    """
    print("Running training...")
    logger = None

    # Build the train/validation loaders (the datasets themselves are not needed here).
    dl_train, dl_val, _, _ = get_fold_dls(df_train, df_valid)

    audio_model = BirdClefModel()

    early_stop_callback = EarlyStopping(monitor="val_loss", min_delta=0.00, patience=Config.PATIENCE, verbose=True, mode="min")
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath="/kaggle/working/exp1/",
        save_top_k=1,
        save_last=True,
        save_weights_only=True,
        filename=f'./{Config.model}_loss',
        verbose=True,
        mode='min',
    )

    # Fall back to CPU instead of failing when no GPU is present. 16-bit mixed
    # precision is only used on GPU (Lightning 1.x CPU AMP supports bf16 only).
    use_gpu = torch.cuda.is_available()
    trainer = pl.Trainer(
        accelerator="gpu" if use_gpu else "cpu",
        devices=1,
        precision=Config.PRECISION if use_gpu else 32,
        val_check_interval=0.5,
        deterministic=True,
        max_epochs=Config.epochs,
        logger=logger,
        callbacks=[checkpoint_callback, early_stop_callback],
    )

    print("Running trainer.fit")
    try:
        trainer.fit(audio_model, train_dataloaders=dl_train, val_dataloaders=dl_val)
    finally:
        # Release host and GPU memory even if training raises.
        gc.collect()
        torch.cuda.empty_cache()
    return checkpoint_callback.best_model_path

In [None]:
# Train the model; dataloaders, callbacks and the Trainer are all set up inside run_training.
run_training()