# In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
import torch.utils.data as data
import torchvision
from torch.autograd import Variable
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.metrics import accuracy_score
import pickle
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader

# In [None]:
# functions
class Dataset_CRNN(data.Dataset):
    """Frame-sequence dataset for the CRNN model.

    Each sample is a sub-directory of ``data_path`` that contains frames named
    ``frame_XXXXXX.jpg`` (six-digit, zero-padded index). ``__getitem__``
    returns ``(X, y)``:

    * ``X`` stacks the selected frames along a new leading (time) dimension.
    * ``y`` is a 1-element ``LongTensor`` holding the label.

    Args:
        data_path: Root directory holding one sub-directory per sample.
        folders: Sample sub-directory names, aligned with ``labels``.
        labels: Integer class index for each entry of ``folders``.
        frames: Frame indices to load, in order (must be non-empty).
        transform: Optional callable applied to every loaded frame. It must
            return a tensor so the frames can be stacked (e.g. a
            ``transforms.Compose`` that contains ``ToTensor``).
        loader: Optional callable ``path -> image``. Defaults to
            torchvision's PIL loader, which opens the file and converts it
            to RGB.

    Raises:
        ValueError: if ``folders`` and ``labels`` differ in length, or
            ``frames`` is empty.
    """

    def __init__(self, data_path, folders, labels, frames, transform=None, loader=None):
        "Initialization"
        if len(folders) != len(labels):
            raise ValueError(
                "folders and labels must have the same length "
                "({} != {})".format(len(folders), len(labels)))
        if len(frames) == 0:
            # torch.stack on an empty list would fail later with an opaque error.
            raise ValueError("frames must contain at least one frame index")
        if loader is None:
            # Local import: only needed for the default loader (the original
            # code referenced PIL's `Image`, which was never imported).
            from torchvision.datasets.folder import pil_loader as loader
        self.data_path = data_path
        self.labels = labels
        self.folders = folders
        self.transform = transform
        self.frames = frames
        self.loader = loader

    def __len__(self):
        "Denotes the total number of samples"
        return len(self.folders)

    def read_images(self, path, selected_folder, use_transform):
        """Load ``self.frames`` from ``path/selected_folder`` and stack them.

        Returns:
            Tensor of shape ``(len(self.frames), *frame_shape)``, where
            ``frame_shape`` is the shape ``use_transform`` (or the loader)
            produces.
        """
        X = []
        for i in self.frames:
            image = self.loader(os.path.join(path, selected_folder, 'frame_{:06d}.jpg'.format(i)))
            if use_transform is not None:
                image = use_transform(image)
            X.append(image)
        return torch.stack(X, dim=0)

    def __getitem__(self, index):
        "Generates one sample of data"
        folder = self.folders[index]

        X = self.read_images(self.data_path, folder, self.transform)  # (input) stacked frames
        y = torch.LongTensor([self.labels[index]])  # (label) int64, as required by cross_entropy

        return X, y

class ResCNNEncoder(nn.Module):
    """Per-frame ResNet-18 feature extractor followed by a trainable MLP head.

    The pretrained ResNet-18 trunk (every child except its final ``fc``) is
    used as a frozen feature extractor:

    * Its forward pass runs under ``torch.no_grad()``.
    * :meth:`train` keeps it in eval mode, so its BatchNorm running
      statistics are never overwritten.

    Only ``fc1``/``bn1``/``fc2``/``bn2``/``fc3`` are meant to be optimised.

    Args:
        fc_hidden1: Width of the first hidden layer.
        fc_hidden2: Width of the second hidden layer.
        drop_p: Dropout probability applied before ``fc3`` (training only).
        CNN_embed_dim: Size of the per-frame embedding produced by ``fc3``.
    """

    def __init__(self, fc_hidden1=512, fc_hidden2=512, drop_p=0.3, CNN_embed_dim=300):
        """Load the pretrained ResNet-18 and replace top fc layer."""
        super(ResCNNEncoder, self).__init__()

        self.fc_hidden1, self.fc_hidden2 = fc_hidden1, fc_hidden2
        self.drop_p = drop_p

        # CNN backbone selection: ResNet-18 with its classification layer removed.
        # NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of `weights=`.
        resnet = models.resnet18(pretrained=True)
        modules = list(resnet.children())[:-1]
        self.resnet = nn.Sequential(*modules)
        # The trunk is frozen: keep its BatchNorm layers on the pretrained statistics.
        self.resnet.eval()

        self.fc1 = nn.Linear(resnet.fc.in_features, fc_hidden1)
        self.bn1 = nn.BatchNorm1d(fc_hidden1, momentum=0.01)
        self.fc2 = nn.Linear(fc_hidden1, fc_hidden2)
        self.bn2 = nn.BatchNorm1d(fc_hidden2, momentum=0.01)
        self.fc3 = nn.Linear(fc_hidden2, CNN_embed_dim)

    def train(self, mode=True):
        """Set train/eval mode on the head; the frozen ResNet trunk always stays in eval mode."""
        super(ResCNNEncoder, self).train(mode)
        # torch.no_grad() in forward() blocks gradients but NOT BatchNorm
        # running-stat updates; only eval mode freezes those.
        self.resnet.eval()
        return self

    def forward(self, x_3d):
        """Embed every time step of ``x_3d`` independently.

        Args:
            x_3d: 5-D tensor sliced as ``x_3d[:, t, :, :, :]`` per time step;
                presumably (batch, time, channels, height, width) -- TODO confirm.

        Returns:
            Tensor of shape (batch, time, CNN_embed_dim).
        """
        cnn_embed_seq = []
        for t in range(x_3d.size(1)):
            # Frozen trunk: no gradients are tracked through the ResNet.
            with torch.no_grad():
                x = self.resnet(x_3d[:, t, :, :, :])
                x = x.view(x.size(0), -1)

            x = self.bn1(self.fc1(x))
            x = F.relu(x)
            x = self.bn2(self.fc2(x))
            x = F.relu(x)
            x = F.dropout(x, p=self.drop_p, training=self.training)
            x = self.fc3(x)
            cnn_embed_seq.append(x)

        # Stacked as (time, batch, embed); swap to (batch, time, embed).
        cnn_embed_seq = torch.stack(cnn_embed_seq, dim=0).transpose_(0, 1)
        return cnn_embed_seq

class DecoderRNN(nn.Module):
    """LSTM classifier over a sequence of per-frame embeddings.

    The LSTM output at the final time step goes through
    ``fc1 -> ReLU -> dropout -> fc2`` to produce unnormalised class scores.

    Args:
        CNN_embed_dim: Feature size of each time step in the input sequence.
        h_RNN_layers: Number of stacked LSTM layers.
        h_RNN: Hidden size of each LSTM layer.
        h_FC_dim: Width of the hidden fully connected layer.
        drop_p: Dropout probability before ``fc2`` (training only).
        num_classes: Number of output logits.
    """

    def __init__(self, CNN_embed_dim=300, h_RNN_layers=3, h_RNN=256, h_FC_dim=128, drop_p=0.3, num_classes=50):
        super().__init__()

        # Hyper-parameters kept as attributes.
        self.num_classes = num_classes
        self.drop_p = drop_p
        self.h_FC_dim = h_FC_dim
        self.h_RNN = h_RNN                  # hidden units per LSTM layer
        self.h_RNN_layers = h_RNN_layers    # number of stacked LSTM layers
        self.RNN_input_size = CNN_embed_dim

        # batch_first=True: the LSTM consumes and emits (batch, time_step, features).
        self.LSTM = nn.LSTM(
            input_size=CNN_embed_dim,
            hidden_size=h_RNN,
            num_layers=h_RNN_layers,
            batch_first=True,
        )
        self.fc1 = nn.Linear(h_RNN, h_FC_dim)
        self.fc2 = nn.Linear(h_FC_dim, num_classes)

    def forward(self, x_RNN):
        """Return class logits of shape (batch, num_classes) for a batch-first sequence."""
        # Re-compact LSTM weights (relevant after DataParallel replication on cuDNN).
        self.LSTM.flatten_parameters()
        # An initial state of None means zero-initialised (h_0, c_0).
        sequence_out, _ = self.LSTM(x_RNN, None)

        final_step = sequence_out[:, -1, :]
        hidden = F.relu(self.fc1(final_step))
        hidden = F.dropout(hidden, p=self.drop_p, training=self.training)
        return self.fc2(hidden)

def labels2cat(label_encoder, actions):
    """Encode class-name labels as integer category indices.

    This delegates to ``label_encoder.transform`` (e.g. a fitted sklearn
    ``LabelEncoder``) and returns its result unchanged.
    """
    encoded = label_encoder.transform(actions)
    return encoded

def CRNN_final_prediction(model, device, loader):
    """Predict a class index for every sample yielded by ``loader``.

    Args:
        model: ``(cnn_encoder, rnn_decoder)`` pair; both are switched to eval mode.
        device: Device the input batches are moved to.
        loader: Iterable of ``(X, y)`` batches; ``y`` is ignored.

    Returns:
        list[int]: Arg-max class index per sample, in loader order.
    """
    cnn_encoder, rnn_decoder = model
    cnn_encoder.eval()
    rnn_decoder.eval()

    all_y_pred = []
    with torch.no_grad():
        # Plain iteration: the original wrapped the loader in `tqdm`, which is
        # never imported in this file and raised NameError.
        for X, _ in loader:
            X = X.to(device)
            output = rnn_decoder(cnn_encoder(X))
            y_pred = output.max(1)[1]  # index of the highest score per sample
            # view(-1) rather than squeeze(): a batch of one must still give a
            # list, otherwise list.extend() receives a bare int and fails.
            all_y_pred.extend(y_pred.view(-1).cpu().tolist())

    return all_y_pred

def train(log_interval, model, device, train_loader, optimizer, epoch):
    """Training function for one epoch.

    Args:
        log_interval: Print progress every ``log_interval`` batches.
        model: ``(cnn_encoder, rnn_decoder)`` pair; both are put in train mode.
        device: Device the batches are moved to.
        train_loader: DataLoader yielding ``(X, y)``. ``y`` is flattened to
            1-D class indices.
        optimizer: Optimizer stepped once per batch.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).

    Returns:
        ``(losses, scores)``: the per-batch mean cross-entropy loss and the
        per-batch accuracy in ``[0, 1]``.
    """
    cnn_encoder, rnn_decoder = model
    cnn_encoder.train()
    rnn_decoder.train()

    losses = []
    scores = []
    N_count = 0  # Count of processed samples

    for batch_idx, (X, y) in enumerate(train_loader):
        X, y = X.to(device), y.to(device).view(-1)
        N_count += X.size(0)

        optimizer.zero_grad()
        output = rnn_decoder(cnn_encoder(X))
        loss = F.cross_entropy(output, y)
        losses.append(loss.item())

        # Batch accuracy, computed in torch. sklearn's accuracy_score on
        # `.squeeze()`d arrays rejected a final batch holding a single sample
        # (0-d array).
        y_pred = torch.max(output, 1)[1]
        step_score = (y_pred == y).sum().item() / y.numel()
        scores.append(step_score)

        loss.backward()
        optimizer.step()

        # Log training progress
        if (batch_idx + 1) % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}, Accu: {:.2f}%'.format(
                epoch + 1, N_count, len(train_loader.dataset),
                100. * (batch_idx + 1) / len(train_loader), loss.item(), 100 * step_score))

    return losses, scores

def validation(model, device, optimizer, test_loader):
    """Evaluate the model on ``test_loader``.

    Args:
        model: ``(cnn_encoder, rnn_decoder)`` pair; both are put in eval mode.
        device: Device the batches are moved to.
        optimizer: Unused; kept so existing call sites keep working.
        test_loader: Iterable of ``(X, y)`` batches. ``y`` is flattened to
            1-D class indices.

    Returns:
        ``(test_loss, test_score)``: the mean per-sample cross-entropy and
        the accuracy in ``[0, 1]``.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    cnn_encoder, rnn_decoder = model
    cnn_encoder.eval()
    rnn_decoder.eval()

    test_loss = 0.0
    n_correct = 0
    n_samples = 0

    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device).view(-1)
            output = rnn_decoder(cnn_encoder(X))
            test_loss += F.cross_entropy(output, y, reduction='sum').item()
            y_pred = output.max(1)[1]
            # Keep running counts instead of stacking per-sample tensors: this
            # avoids squeeze() collapsing a single sample to a 0-d array.
            n_correct += (y_pred == y).sum().item()
            n_samples += y.numel()

    if n_samples == 0:
        raise ValueError("validation loader yielded no samples")

    # Normalise by the samples actually evaluated (not len(dataset)).
    test_loss /= n_samples
    test_score = n_correct / n_samples

    print('\nValid set ({:d} samples): Average loss: {:.4f}, Accuracy: {:.2f}%\n'.format(
        n_samples, test_loss, 100 * test_score))

    return test_loss, test_score

## ---------------------- end of CRNN module ---------------------- ##


# In [None]:
# Data path configuration
data_path = "./WB_Set B_jpg"

# Define result saving paths
save_path = "./output__crnn"
save_model_path = os.path.join(save_path, "models")

os.makedirs(save_path, exist_ok=True)
os.makedirs(save_model_path, exist_ok=True)

# Model parameters
CNN_fc_hidden1, CNN_fc_hidden2 = 1024, 768
CNN_embed_dim = 512
res_size = 224
dropout_p = 0.3
RNN_hidden_layers = 3
RNN_hidden_nodes = 512
RNN_FC_dim = 256
k = 3  # Number of classes
epochs = 50
batch_size = 50
learning_rate = 1e-3
log_interval = 10  # Log training info every N batches
patience = 5  # Early stopping: stop after this many consecutive epochs without improvement
no_improve_count = 0  # Counter for early stopping (reset at the start of every fold)
begin_frame, end_frame, skip_frame = 1, 60, 1  # Frame selection parameters (end_frame is exclusive)

# CUDA settings
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
params = {'batch_size': batch_size, 'shuffle': True, 'num_workers': 4, 'pin_memory': True} if use_cuda else {'batch_size': batch_size, 'shuffle': True}
# Validation metrics do not depend on sample order, so the validation loader is not shuffled.
valid_params = {**params, 'shuffle': False}

# Action class names and label encoding
action_names = ['Plunging', 'Spilling', 'Surging']
le = LabelEncoder()
le.fit(action_names)
action_category = le.transform(action_names).reshape(-1, 1)
enc = OneHotEncoder()
enc.fit(action_category)

# Load all sample folders and extract the action from the "<action>_..." name prefix.
# os.listdir() returns entries in arbitrary order; sorting makes the
# StratifiedKFold split below reproducible across runs, which the
# "skip completed folds" resume logic relies on.
fnames = sorted(os.listdir(data_path))
actions = []
all_names = []
for f in fnames:
    if not os.path.isdir(os.path.join(data_path, f)):
        # Each sample is read as a folder of frames; stray files (e.g. .DS_Store)
        # would otherwise be parsed as samples with a bogus action label.
        print(f"Skipping non-directory entry: {f}")
        continue
    loc = f.find('_')
    if loc == -1:
        print(f"Unexpected file format: {f}")
        continue
    action = f[:loc]
    actions.append(action)
    all_names.append(f)

# Prepare data lists
all_X_list = all_names
all_y_list = labels2cat(le, actions)

# Setup stratified k-fold cross validation
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Image transformations (standard ImageNet mean/std normalisation)
transform = transforms.Compose([
    transforms.Resize([res_size, res_size]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Frame selection
selected_frames = np.arange(begin_frame, end_frame, skip_frame).tolist()

# Best validation score across all folds, including folds finished in earlier runs
best_valid_score = -float('inf')
fold = 1  # Current fold counter

# File suffixes to check for completed folds
file_suffixes = [
    'epoch_training_losses.npy',
    'epoch_training_scores.npy',
    'epoch_validation_loss.npy',
    'epoch_validation_score.npy',
]

# Main training loop with k-fold cross validation
for train_index, valid_index in skf.split(all_X_list, all_y_list):
    # Check if current fold is already completed
    fold_completed = all(os.path.exists(os.path.join(save_path, f'CRNN_fold_{fold}_{suffix}')) for suffix in file_suffixes)

    if fold_completed:
        print(f"\nFold {fold} already completed. Skipping to next fold...\n")
        # Fold the saved result into the overall best so a resumed run reports
        # the best score across all folds, not only the ones trained now.
        saved_scores = np.load(os.path.join(save_path, f'CRNN_fold_{fold}_epoch_validation_score.npy'))
        if saved_scores.size > 0:
            best_valid_score = max(best_valid_score, float(saved_scores.max()))
        fold += 1
        continue

    print(f"Fold {fold}:")
    # Split data into training and validation sets
    train_list = [all_X_list[i] for i in train_index]
    valid_list = [all_X_list[i] for i in valid_index]
    train_label = [all_y_list[i] for i in train_index]
    valid_label = [all_y_list[i] for i in valid_index]

    # Create datasets and data loaders
    train_set = Dataset_CRNN(data_path, train_list, train_label, selected_frames, transform=transform)
    valid_set = Dataset_CRNN(data_path, valid_list, valid_label, selected_frames, transform=transform)
    train_loader = DataLoader(train_set, **params)
    valid_loader = DataLoader(valid_set, **valid_params)

    # Initialize models
    cnn_encoder = ResCNNEncoder(fc_hidden1=CNN_fc_hidden1, fc_hidden2=CNN_fc_hidden2,
                                drop_p=dropout_p, CNN_embed_dim=CNN_embed_dim).to(device)
    rnn_decoder = DecoderRNN(CNN_embed_dim=CNN_embed_dim, h_RNN_layers=RNN_hidden_layers,
                             h_RNN=RNN_hidden_nodes, h_FC_dim=RNN_FC_dim,
                             drop_p=dropout_p, num_classes=k).to(device)

    # Multi-GPU support if available
    if torch.cuda.device_count() > 1:
        print("Using", torch.cuda.device_count(), "GPUs!")
        cnn_encoder = nn.DataParallel(cnn_encoder)
        rnn_decoder = nn.DataParallel(rnn_decoder)
        encoder_module = cnn_encoder.module
    else:
        print("Using", torch.cuda.device_count(), "GPU!")
        encoder_module = cnn_encoder

    # Only the encoder's MLP head and the whole decoder are optimised; the
    # ResNet trunk stays frozen.
    crnn_params = (list(encoder_module.fc1.parameters()) + list(encoder_module.bn1.parameters()) +
                   list(encoder_module.fc2.parameters()) + list(encoder_module.bn2.parameters()) +
                   list(encoder_module.fc3.parameters()) + list(rnn_decoder.parameters()))

    # Initialize optimizer
    optimizer = torch.optim.Adam(crnn_params, lr=learning_rate)

    # Initialize lists to track metrics
    epoch_train_losses = []
    epoch_train_scores = []
    epoch_valid_losses = []
    epoch_valid_scores = []

    # Early stopping counters
    no_improve_count = 0
    current_best_valid_score = -float('inf')

    # Training loop for current fold
    for epoch in range(epochs):
        train_losses, train_scores = train(log_interval, [cnn_encoder, rnn_decoder], device, train_loader, optimizer, epoch)
        epoch_valid_loss, epoch_valid_score = validation([cnn_encoder, rnn_decoder], device, optimizer, valid_loader)

        # Store metrics
        epoch_train_losses.append(train_losses)
        epoch_train_scores.append(train_scores)
        epoch_valid_losses.append(epoch_valid_loss)
        epoch_valid_scores.append(epoch_valid_score)

        # Check for improvement and save best model
        if epoch_valid_score > current_best_valid_score:
            current_best_valid_score = epoch_valid_score
            no_improve_count = 0
            # Save model state dicts
            torch.save(cnn_encoder.state_dict(), os.path.join(save_model_path, f'best_cnn_encoder_fold_{fold}.pth'))
            torch.save(rnn_decoder.state_dict(), os.path.join(save_model_path, f'best_rnn_decoder_fold_{fold}.pth'))
            torch.save(optimizer.state_dict(), os.path.join(save_model_path, f'best_optimizer_fold_{fold}.pth'))
            # Save full models
            torch.save(cnn_encoder, os.path.join(save_model_path, f'best_cnn_encoder_full_fold_{fold}.pth'))
            torch.save(rnn_decoder, os.path.join(save_model_path, f'best_rnn_decoder_full_fold_{fold}.pth'))
            print(f"Epoch {epoch + 1} in Fold {fold}: Best model saved with validation score {epoch_valid_score:.2f}")
        else:
            no_improve_count += 1
            print(f"Epoch {epoch + 1} in Fold {fold}: No improvement in validation score. Count={no_improve_count}")
            # Stop once `patience` consecutive epochs passed without improvement
            # (the previous `>` allowed patience + 1 such epochs).
            if no_improve_count >= patience:
                print("Early stopping triggered for current fold.")
                break

    # Update global best validation score
    best_valid_score = max(best_valid_score, current_best_valid_score)

    # Save training metrics for current fold
    A = np.array(epoch_train_losses)
    B = np.array(epoch_train_scores)
    C = np.array(epoch_valid_losses)
    D = np.array(epoch_valid_scores)

    np.save(os.path.join(save_path, f'CRNN_fold_{fold}_epoch_training_losses.npy'), A)
    np.save(os.path.join(save_path, f'CRNN_fold_{fold}_epoch_training_scores.npy'), B)
    np.save(os.path.join(save_path, f'CRNN_fold_{fold}_epoch_validation_loss.npy'), C)
    np.save(os.path.join(save_path, f'CRNN_fold_{fold}_epoch_validation_score.npy'), D)

    print(f"Fold {fold} training complete.")

    # Clean up memory: the optimizer and parameter list also hold references
    # to the trainable tensors, so drop them together with the models.
    del cnn_encoder, rnn_decoder, encoder_module, optimizer, crnn_params
    torch.cuda.empty_cache()

    fold += 1  # Move to next fold

print(f"Cross-validation complete. Best validation score: {best_valid_score:.4f}")