In [1]:
import warnings

import numpy as np
import torch
from torch import device, cuda, autocast
from torch.cuda.amp import GradScaler
from torch.nn import BCEWithLogitsLoss
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

import wandb
from conv_models import AlternativeCNN
from from_video_to_training_batched_funcs import get_files_from_directory, select_random_videos, paths_to_labels, load_custom_sequences
from logs_to_wandb import log_original_to_wandb

# Suppress RuntimeWarnings carrying this exact message when they originate
# from the wandb.sdk.data_types.image module.
warnings.filterwarnings(
    action='ignore',
    message='invalid value encountered in cast',
    category=RuntimeWarning,
    module='wandb.sdk.data_types.image',
)

# --- Reproducibility -------------------------------------------------------
torch.manual_seed(42)

# --- Hardware selection: use CUDA when available, otherwise the CPU ---------
if cuda.is_available():
    device_type = "cuda"
else:
    device_type = "cpu"
DEVICE = device(device_type)

# --- Data ------------------------------------------------------------------
TRAINING_DATA_DIR = "easy_videos"
TESTING_DATA_DIR = "easyval_videos"
batch_size = 100        # number of videos requested per training step
last_good_frame = 2     # used as a negative index: the frame taken is sequence[-last_good_frame]

# --- Run control -----------------------------------------------------------
num_epochs = 1
debugging = True        # when True the loop runs `debug_length` iterations per epoch
debug_length = 100

# --- Weights & Biases logging ----------------------------------------------
wandb_ = True           # master switch for wandb logging
wandb_images_every = 10  # an input image is logged every this many iterations

  _C._set_default_tensor_type(t)


In [2]:
# Gather the file listings for the training directory and the validation
# directory (in that order) with the same helper.
all_videos, all_validation_videos = (
    get_files_from_directory(data_dir)
    for data_dir in (TRAINING_DATA_DIR, TESTING_DATA_DIR)
)

In [3]:
# Build the classifier and move its parameters to the selected device.
model = AlternativeCNN().to(DEVICE)

# Adam over all model parameters with a fixed learning rate.
lr = .001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Gradient scaler for float16 mixed precision. Loss scaling only applies on
# CUDA, so enable it explicitly there; on CPU the scaler is created disabled
# and acts as a pass-through, instead of warning and disabling itself.
scaler = GradScaler(enabled=(device_type == "cuda"))

# Initialize the loss function: binary cross-entropy computed on raw logits
# (the sigmoid is applied inside the loss).
criterion = BCEWithLogitsLoss()

In [None]:
model.train()

# In debugging mode run a fixed number of iterations per epoch; otherwise use
# as many full batches as the training file list provides.
iterations = debug_length if debugging else len(all_videos) // batch_size
# Start wandb run
if wandb_:
    wandb.init(project='binary_classification')
    data_table = wandb.Table(columns=["Predictions", "True Labels"])

# Training loop
for epoch in range(num_epochs):
    already_selected = []  # threaded through select_random_videos on every step
    running_loss = 0.0     # sum of per-sample losses seen so far this epoch
    correct = 0
    total = 0
    for i in tqdm(range(iterations)):
        batch_files, already_selected = select_random_videos(
            all_videos, batch_size, already_selected
        )
        labels = paths_to_labels(batch_files)
        batch_sequences = load_custom_sequences(batch_files)
        # Keep one slice along axis 1, counted from the end
        # (assumes axis 1 is the frame axis -- TODO confirm).
        batch_sequences = batch_sequences[:, -last_good_frame, :, :]
        if wandb_ and i % wandb_images_every == 0:
            log_original_to_wandb(batch_sequences[0], batch_files[0])

        # Add channel dimension and normalize
        # (presumably 8-bit pixel values, hence the division by 256 -- TODO confirm).
        batch_sequences = np.divide(batch_sequences[:, None, :, :], 256)

        # Ensure the data is in tensor form and on the correct device
        labels_tensor = torch.tensor(labels, dtype=torch.float).to(DEVICE)
        batch_sequences_tensor = torch.tensor(batch_sequences, dtype=torch.float).to(DEVICE)

        # Create the dataset and dataloader
        dataset = TensorDataset(batch_sequences_tensor, labels_tensor)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # Both tensors were moved to DEVICE above, so the batches yielded here
        # need no further transfer.
        for inputs, batch_labels in dataloader:
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass: only the forward computation and the loss run
            # under autocast.
            with autocast(device_type):
                outputs = model(inputs)
                # Flatten to one logit per sample; unlike a bare squeeze() this
                # keeps a 1-D shape even when the batch holds a single sample.
                logits = outputs.reshape(-1)
                loss = criterion(logits, batch_labels)

            # Backward pass and optimize, outside autocast. The scaler guards
            # float16 gradients against underflow and is a pass-through when
            # it is disabled.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            # Calculate the running loss and accuracy on detached values so no
            # autograd graph is kept alive by the bookkeeping.
            detached_logits = logits.detach().float()
            running_loss += loss.item() * inputs.size(0)
            predictions = torch.sigmoid(detached_logits) > 0.5
            total += batch_labels.size(0)
            correct += (predictions == batch_labels).sum().item()

            if wandb_:
                # Store plain Python floats (raw model outputs and the labels
                # of this batch) rather than graph-attached device tensors.
                for pred, label in zip(detached_logits.cpu().tolist(),
                                       batch_labels.cpu().tolist()):
                    data_table.add_data(pred, label)

        # Log metrics to wandb
        if wandb_:
            wandb.log({
                'epoch': epoch,
                'loss': running_loss / total,
                'accuracy': correct / total,
            })
    # data_table only exists when wandb logging is enabled.
    if wandb_:
        wandb.log({"predictions_vs_labels": data_table})

    if total:
        print(f'Epoch {epoch+1}, Loss: {running_loss / total}, Accuracy: {correct / total}')
    else:
        # No iteration ran (e.g. fewer videos than one batch); avoid dividing by zero.
        print(f'Epoch {epoch+1}: no samples were processed')

VBox(children=(Label(value='0.006 MB of 0.006 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01111234893333454, max=1.0)…

 31%|███       | 31/100 [03:45<07:57,  6.92s/it]