In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from torch.utils.data import DataLoader, TensorDataset
from dataset import DatasetCIFAR10
import wandb

### Parameters.

In [None]:
# Classifier parameters.
CLASSIFIER_NUMBER_OF_CLASSES = 10  # Also the starting value of every episode's duration counter further down.
CLASSIFIER_NUMBER_OF_EPOCHS = 50  # Forwarded to LalEnvFirstAccuracy as `epochs`.
CLASSIFIER_LEARNING_RATE = 0.01
CLASSIFIER_BATCH_SIZE = 64  # Batch size of the train/test DataLoaders; also forwarded to the environment.

# Parameters for both agents.

# NOTE(review): 5e4 is a float literal, not an int; presumably ReplayBuffer
# tolerates a float `buffer_size` - confirm against batch_helpers.
REPLAY_BUFFER_SIZE = 5e4
# NOTE(review): "PRIOROTIZED" is a misspelling of "PRIORITIZED". The name is
# referenced again where the ReplayBuffer is built, so a rename must change both.
PRIOROTIZED_REPLAY_EXPONENT = 3

BATCH_SIZE = 32  # Mini-batch size sampled from the replay buffer for each DQN update.
LEARNING_RATE = 1e-3  # Passed to the DQN constructor.
TARGET_COPY_FACTOR = 0.01  # Passed to the DQN constructor.
BIAS_INITIALIZATION = 0  # Placeholder: overwritten after the warm-start episodes, before the DQN is built.

NUMBER_OF_STATE_DATA = 1000  # Passed to DatasetCIFAR10 and used as the DQN `observation_length`.
TRAIN_DATASET_LENGTH = 5000  # Used for the first dataset only; reassigned to 50000 before agent training.

# BatchAgent's parameters.

DIRNAME = './batch_agent/' # Directory in which the batch_agent produced by this experiment is to be written.

WARM_START_EPISODES_BATCH_AGENT = 5  # Number of random-action episodes run to pre-fill the replay buffer.
NN_UPDATES_PER_EPOCHS_BATCH_AGENT = 50  # DQN updates run after the warm start and after every training epoch.

TRAINING_EPOCHS_BATCH_AGENT = 5
TRAINING_EPISODES_PER_EPOCH_BATCH_AGENT = 5

In [None]:
import os
import shutil

cwd = os.getcwd()

# Directories left behind by earlier runs, written as suffixes of the working directory.
stale_suffixes = ('/__pycache__', '/wandb', '/batch_agent', '/libact', '/AL_results', '/checkpoints', '/summaries')

# Remove every leftover directory that is present; errors raised while deleting are ignored.
for directory in (cwd + suffix for suffix in stale_suffixes):
    if not os.path.exists(directory):
        continue
    shutil.rmtree(directory, ignore_errors=True)

In [None]:
# Load the dataset.
dataset = DatasetCIFAR10(
    number_of_state_data=NUMBER_OF_STATE_DATA,
    train_dataset_length=TRAIN_DATASET_LENGTH,
)

# Training split: float inputs, long labels, reshuffled by the loader.
train_tensors = TensorDataset(
    torch.tensor(dataset.train_data).float(),
    torch.tensor(dataset.train_labels).long(),
)
train_loader = DataLoader(train_tensors, batch_size=CLASSIFIER_BATCH_SIZE, shuffle=True)

# Test split: same conversion, iterated in a fixed order.
test_tensors = TensorDataset(
    torch.tensor(dataset.test_data).float(),
    torch.tensor(dataset.test_labels).long(),
)
test_loader = DataLoader(test_tensors, batch_size=CLASSIFIER_BATCH_SIZE, shuffle=False)

# Report the size of each split.
print("Train data are {}.".format(len(dataset.train_data)))
print("State data are {}.".format(len(dataset.state_data)))
print("Test data are {}.".format(len(dataset.test_data)))

In [None]:
# Initialize the model.
class CNNClassifier(nn.Module):
    """Pretrained ResNet-18 with a frozen backbone and a fresh linear head.

    Every backbone parameter has ``requires_grad`` switched off; only the
    replacement ``fc`` layer, created afterwards, is trainable.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer. When
            omitted, the notebook-level ``CLASSIFIER_NUMBER_OF_CLASSES`` is
            used, so the head follows the configuration cell instead of a
            hard-coded 10.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        if num_classes is None:
            # Looked up at call time so the configuration cell stays the single source of truth.
            num_classes = CLASSIFIER_NUMBER_OF_CLASSES
        # NOTE(review): newer torchvision releases deprecate `pretrained=True`
        # in favour of the `weights=` argument; kept as-is so older installs keep working.
        self.resnet18 = models.resnet18(pretrained=True)
        # Freeze the backbone before attaching the new head.
        for param in self.resnet18.parameters():
            param.requires_grad = False
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Reshape ``x`` to ``(-1, 3, 32, 32)`` and return the ResNet-18 output."""
        x = x.reshape(-1, 3, 32, 32)
        return self.resnet18(x)

# Initialize the model and device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classifier = CNNClassifier()
classifier.to(device)

# Define the loss function and optimizer (learning rate comes from the configuration cell).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(classifier.parameters(), lr=CLASSIFIER_LEARNING_RATE)

# Train the model.
# NOTE(review): the epoch count is hard-coded to 10 while CLASSIFIER_NUMBER_OF_EPOCHS
# (the value handed to the environment) is 50 - confirm the difference is intended.
for epoch in range(10):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = classifier(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    # Reports the loss of the last mini-batch of the epoch, not an epoch average.
    print(f'Epoch {epoch+1}, Loss: {loss.item()}')

# Evaluate the model on the test data.
test_loss = 0
true_positives = 0
false_positives = 0

# Evaluate in eval mode so BatchNorm uses its running statistics and is not
# updated by test batches; the previous mode is restored afterwards because the
# same classifier object is handed to the environment later.
was_training = classifier.training
classifier.eval()
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = classifier(inputs)
        _, predicted = torch.max(outputs, 1)
        test_loss += criterion(outputs, labels).item()
        # A correct prediction is a true positive for its class; a wrong one is
        # a false positive for the predicted class. (The previous code
        # subtracted the true-positive count from itself, so it always counted 0.)
        true_positives += (predicted == labels).sum().item()
        false_positives += (predicted != labels).sum().item()
classifier.train(was_training)

# Micro-averaged precision; 0.0 when the test loader produced no samples.
total_predictions = true_positives + false_positives
TARGET_PRECISION = true_positives / total_predictions if total_predictions else 0.0

print(f'Test Loss: {test_loss / max(len(test_loader), 1)}')
print(f'Test Precision: {TARGET_PRECISION:.2f}')

In [None]:
from batch_envs import LalEnvFirstAccuracy
# Environment for the warm-start phase, built around the classifier trained above.
batch_env = LalEnvFirstAccuracy(
    dataset,
    classifier,
    epochs=CLASSIFIER_NUMBER_OF_EPOCHS,
    classifier_batch_size=CLASSIFIER_BATCH_SIZE,
    target_precision=TARGET_PRECISION,
)

In [None]:
from batch_helpers import ReplayBuffer
# Replay buffer that receives the warm-start and training transitions.
replay_buffer = ReplayBuffer(
    buffer_size=REPLAY_BUFFER_SIZE,
    prior_exp=PRIOROTIZED_REPLAY_EXPONENT,
)

In [None]:
# WARM-START EPISODES.
# Runs WARM_START_EPISODES_BATCH_AGENT episodes with random actions, stores the
# transitions in the replay buffer, and derives BIAS_INITIALIZATION from the
# average episode duration.

import torch
import numpy as np

# Initialize the variables.
episode_durations = []
episode_scores = []
episode_number = 1
episode_losses = []
episode_precisions = []
batches = []

# Warm start procedure.
for _ in range(WARM_START_EPISODES_BATCH_AGENT):
    print("Episode {}.".format(episode_number))
    # Reset the environment to start a new episode.
    state, next_action, indicies_unknown, reward = batch_env.reset(isBatchAgent=False, target_budget=1.0)
    done = False
    # The duration counter starts at CLASSIFIER_NUMBER_OF_CLASSES
    # (presumably the size of the initial labelled set - TODO confirm).
    episode_duration = CLASSIFIER_NUMBER_OF_CLASSES

    # Before we reach a terminal state, make steps.
    while not done:
        # Choose a random batch size in [0, n_actions).
        batch = torch.randint(0, batch_env.n_actions, (1,)).item()
        batches.append(batch)

        # Candidate action indices: 0 .. n_actions - 1.
        input_numbers = range(0, batch_env.n_actions)

        # Draw `batch` distinct indices. replace=False is required for that:
        # np.random.choice samples WITH replacement by default, which let the
        # same index appear several times in one batch.
        batch_actions_indices = torch.tensor(np.random.choice(input_numbers, batch, replace=False))
        action = batch
        print("- Step.")
        next_state, next_action, indicies_unknown, reward, done = batch_env.step(batch_actions_indices)

        # Make sure the stored next-action list is never empty.
        if next_action == []:
            next_action.append(torch.tensor([0]))

        # Store the transition in the replay buffer (with the `done` flag reported by the environment).
        print("- Buffer.")
        replay_buffer.store_transition(state, action, reward, next_state, next_action, done)

        # Get ready for the next step.
        state = next_state
        episode_duration += batch

        # NOTE(review): this unconditionally ends the episode after one step,
        # whatever the environment returned - confirm that single-step
        # warm-start episodes are intended.
        done = True

    # Record the final accuracy, precision and duration of the episode.
    episode_final_acc = batch_env.return_episode_qualities()
    episode_scores.append(episode_final_acc[-1])
    episode_final_precision = batch_env.return_episode_precisions()
    episode_precisions.append(episode_final_precision[-1])
    episode_durations.append(episode_duration)
    episode_number += 1

# Compute the average episode duration of episodes generated during the warm start procedure.
av_episode_duration = np.mean(episode_durations)
BIAS_INITIALIZATION = -av_episode_duration / 2

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d


def _plot_episode_metric(ax, values, color, title, ylabel, legend_prefix=None):
    """Draw one per-episode metric on ``ax``: markers plus a smoothed curve.

    Args:
        ax: Matplotlib axes to draw on.
        values: One number per episode.
        color: Matplotlib colour used for both the markers and the curve.
        title: Left-aligned panel title.
        ylabel: Label of the y axis.
        legend_prefix: When given, a legend reading
            ``"<prefix><maximum of values>"`` is added.
    """
    ypoints = np.asarray(torch.tensor(values), dtype=float)
    xpoints = np.arange(len(ypoints))
    ax.plot(xpoints, ypoints, 'o', color=color)

    # A cubic interpolant needs at least four samples; with fewer, show the markers only.
    if len(ypoints) >= 4:
        xnew = np.linspace(xpoints.min(), xpoints.max(), 500)
        smooth = interp1d(xpoints, ypoints, kind='cubic')
        ax.plot(xnew, smooth(xnew), color=color)

    ax.set_title(title, loc='left')
    ax.set_xlabel("Episodes")
    ax.set_ylabel(ylabel)

    if legend_prefix is not None and len(ypoints) > 0:
        # Format the number itself instead of slicing its string representation.
        ax.legend(["{}{:.2f}".format(legend_prefix, ypoints.max())])


# One figure with one panel per metric. (Previously every metric opened a new
# figure and the last two both targeted subplot slot 3.)
budget_title = "Budget per episode. *Size of unlabeled data: " + str(len(dataset.train_data))
fig, axes = plt.subplots(4, 1, figsize=(20, 16))

# Total budget size per episode (absolute number of data points).
_plot_episode_metric(axes[0], episode_durations, 'm', budget_title,
                     "Budget size (number of data points)")

# Total budget size per episode as a fraction of the unlabeled data (UD).
_plot_episode_metric(axes[1], [x/len(dataset.train_data) for x in episode_durations], 'k', budget_title,
                     "Budget size (fraction of the UD)")

# Final achieved accuracy per episode.
_plot_episode_metric(axes[2], episode_scores, 'c', "Final achieved accuracy per episode",
                     "ACC", legend_prefix="Maximum ACC: ")

# Final achieved precision per episode.
_plot_episode_metric(axes[3], episode_precisions, 'y', "Final achieved precision per episode",
                     "Precision", legend_prefix="Maximum precision: ")

fig.tight_layout()
plt.show()

In [None]:
import torch

# Work on the precisions as a tensor and find the best value reached during the warm start.
episode_precisions = torch.tensor(episode_precisions)
max_precision = torch.max(episode_precisions)

# Durations of the episodes whose precision reached that maximum.
warm_start_batches = [
    episode_durations[episode_index]
    for episode_index, precision in enumerate(episode_precisions)
    if precision >= max_precision
]

# The target budget is the shortest such duration divided by the training-set size.
shortest_best_duration = torch.min(torch.tensor(warm_start_batches))
TARGET_BUDGET = shortest_best_duration / len(dataset.train_data)
print("Target budget is {}.".format(TARGET_BUDGET))

In [None]:
import torch

# Training-set length used from here on (replaces the value set in the parameters cell).
TRAIN_DATASET_LENGTH = 50000

# Rebuild the dataset with that length; the length is handed over as a long tensor.
train_dataset_length_tensor = torch.tensor(TRAIN_DATASET_LENGTH).long()
dataset = DatasetCIFAR10(
    number_of_state_data=NUMBER_OF_STATE_DATA,
    train_dataset_length=train_dataset_length_tensor,
)
print("Train data are {}.".format(len(dataset.train_data)))
print("State data are {}.".format(len(dataset.state_data)))
print("Test data are {}.".format(len(dataset.test_data)))

# Rebuild the environment on the new dataset; the target precision is passed as a float tensor.
target_precision_tensor = torch.tensor(TARGET_PRECISION).float()
batch_env = LalEnvFirstAccuracy(
    dataset,
    classifier,
    epochs=CLASSIFIER_NUMBER_OF_EPOCHS,
    classifier_batch_size=CLASSIFIER_BATCH_SIZE,
    target_precision=target_precision_tensor,
)

In [None]:
from batch_dqn import DQN
# Constructor settings of the batch agent, gathered in one place.
dqn_settings = dict(
    observation_length=NUMBER_OF_STATE_DATA,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    target_copy_factor=TARGET_COPY_FACTOR,
    bias_average=BIAS_INITIALIZATION,
)
batch_agent = DQN(**dqn_settings)

In [None]:
# Train the agent on the warm-start transitions before the first training epoch.
for update in range(NN_UPDATES_PER_EPOCHS_BATCH_AGENT):
    update_number = update + 1
    print("Update:", update_number)

    # Sample a mini-batch, train on it, then write the TD errors back to the sampled entries.
    minibatch = replay_buffer.sample_minibatch(BATCH_SIZE)
    td_error = batch_agent.train(minibatch)
    replay_buffer.update_td_errors(td_error, minibatch.indices)

In [None]:
# BATCH-AGENT TRAINING.

# Best-of-epoch summaries, one entry per training epoch.
agent_epoch_durations, agent_epoch_scores, agent_epoch_precisions = [], [], []

for epoch in range(TRAINING_EPOCHS_BATCH_AGENT):
    print("Training epoch {}.".format(epoch+1))

    # Results of the episodes simulated during this epoch.
    agent_episode_durations, agent_episode_scores, agent_episode_precisions = [], [], []

    for training_episode in range(TRAINING_EPISODES_PER_EPOCH_BATCH_AGENT):
        # Start a new episode from a freshly reset environment.
        state, action_batch, action_unlabeled_data, reward = batch_env.reset(isBatchAgent=True, target_budget=TARGET_BUDGET)
        episode_duration = CLASSIFIER_NUMBER_OF_CLASSES
        first_batch = True
        done = False

        # Step until the environment reports a terminal state.
        while not done:
            # The first step uses what reset() returned; later steps use what the previous step() returned.
            next_batch = action_batch if first_batch else next_action_batch_size
            next_unlabeled_data = action_unlabeled_data if first_batch else next_action_unlabeled_data
            first_batch = False

            selected_batch, selected_indices = batch_agent.get_action(
                dataset=dataset,
                model=classifier,
                state=state,
                next_action_batch=next_batch,
                next_action_unlabeled_data=next_unlabeled_data,
            )
            next_state, next_action_batch_size, next_action_unlabeled_data, reward, done = batch_env.step(selected_indices)

            # Make sure the stored next-action list is never empty.
            if next_action_batch_size==[]:
                next_action_batch_size.append(np.array([0]))

            replay_buffer.store_transition(state, selected_batch, reward, next_state, next_action_batch_size, done)

            # Advance: the next state becomes a float32 tensor on the classifier's device.
            state = torch.tensor(next_state, dtype=torch.float32).to(device)
            episode_duration += selected_batch
            print("- Selected batch is {}.".format(selected_batch))

        # Record the last accuracy/precision reported for the episode and its duration.
        agent_episode_final_acc = batch_env.return_episode_qualities()
        agent_episode_scores.append(agent_episode_final_acc[-1])
        agent_episode_final_precision = batch_env.return_episode_precisions()
        agent_episode_precisions.append(agent_episode_final_precision[-1])
        agent_episode_durations.append(episode_duration)

    # Episodes of this epoch that reached the epoch's best precision.
    maximum_epoch_precision = max(agent_episode_precisions)
    best_episode_indices = [
        episode_index
        for episode_index, episode_precision in enumerate(agent_episode_precisions)
        if episode_precision == maximum_epoch_precision
    ]
    minimum_batches_for_the_maximum_epoch_precision = [agent_episode_durations[k] for k in best_episode_indices]
    accuracy_for_the_maximum_epoch_precision = [agent_episode_scores[k] for k in best_episode_indices]

    # Epoch summary: best precision, the accuracies of the episodes that
    # reached it (stored as a list), and the shortest of their durations.
    agent_epoch_precisions.append(maximum_epoch_precision)
    agent_epoch_scores.append(accuracy_for_the_maximum_epoch_precision)
    agent_epoch_durations.append(min(minimum_batches_for_the_maximum_epoch_precision))

    # NEURAL NETWORK UPDATES.
    for update in range(NN_UPDATES_PER_EPOCHS_BATCH_AGENT):
        minibatch = replay_buffer.sample_minibatch(BATCH_SIZE)
        td_error = batch_agent.train(minibatch)
        replay_buffer.update_td_errors(td_error, minibatch.indices)