In [4]:
import torch
from experiment_utils import setup_experiment, train, evaluate_model
from pruning import conduct_experiment
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# MLP

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Vocabulary-size sweep: each min_freq value becomes one curve in every panel.
min_freqs = [2, 20, 80]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # 5 distinct colors
markers = ['o', 's', '^', 'D', '*']  # 5 different markers

# Legend label for each min_freq value.
legends = {2: 'big', 20: 'medium', 80: 'small'}

for dataset in ['reviews', 'imdb']:
    for model_type in ['bow_mlp']:
        # NOTE(review): train_dir/test_dir point at Reviews.csv and num_classes is 5 even when
        # dataset == 'imdb'; presumably setup_experiment resolves these per dataset — confirm.
        config = {
            'batch_size': 32,
            'seed': 42,
            'num_epochs': 25,
            'learning_rate': 1e-4,
            'device': device,
            'dataset_name': dataset,
            'train_dir': 'Reviews.csv',
            'test_dir': 'Reviews.csv',
            'text_col': 'Text',
            'label_col': 'Score',
            'max_length': 256,
            'num_classes': 5,
            'model_type': model_type,
            'model_size': 'small',
            'optimizer_type': 'adam',
        }
        fig, ax = plt.subplots(2, 2, figsize=(9.2, 6.2), dpi=300)
        for min_freq, color, marker in zip(min_freqs, colors, markers):
            config['min_freq'] = min_freq
            for wd in [0]:
                torch.manual_seed(config['seed'])
                (train_dataloader, test_dataloader), model = setup_experiment(config)
                history = train(
                    model=model,
                    train_dataloader=train_dataloader,
                    val_dataloader=test_dataloader,
                    dataset_name=config['dataset_name'],
                    model_type=config['model_type'],
                    model_size=config['model_size'],
                    num_epochs=config['num_epochs'],
                    optimizer_type=config['optimizer_type'],
                    lr=config['learning_rate'],
                    weight_decay=wd,
                    device=config['device'],
                    save_path='./results/models'
                )
                for prune_emb in [False]:
                    print('===================================')
                    print('===================================')
                    print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                    # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding
                    # 'cuda', so the cell also runs on CPU-only machines.
                    results_df = conduct_experiment(
                        model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                        regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                        save_path='results/metrics', weight_decay=wd, min_freq=min_freq, plot=False)

                    x_tr = results_df['threshold']
                    y_ac = results_df['accuracy']
                    x_den = results_df['sparsity']
                    # The first free-energy row is dropped (the x-series are sliced the same way
                    # when plotted), then the curve is min-max scaled to [0, 1]. A flat curve would
                    # divide by zero, so it is mapped to 0 instead.
                    fe_raw = results_df['free_energy'][1:]
                    fe_min = fe_raw.min()
                    fe_span = fe_raw.max() - fe_min
                    y_fe = (fe_raw - fe_min) / fe_span if fe_span else fe_raw - fe_min
                    print(y_fe)

                    curve_label = f'vocab_size={legends[min_freq]}'
                    for panel, xs, ys in (
                        (ax[0, 0], x_tr, y_ac),       # accuracy vs threshold
                        (ax[0, 1], x_den, y_ac),      # accuracy vs sparsity
                        (ax[1, 0], x_tr[1:], y_fe),   # free energy vs threshold
                        (ax[1, 1], x_den[1:], y_fe),  # free energy vs sparsity
                    ):
                        panel.plot(xs, ys, color=color, linewidth=1.5, marker=marker,
                                   markersize=6, label=curve_label)

        # Panel letters run down each column: (a)/(b) on the left, (c)/(d) on the right.
        for panel, xlabel, ylabel, tag in (
            (ax[0, 0], 'Threshold', 'Accuracy (%)', '(a)'),
            (ax[0, 1], 'Sparsity', 'Accuracy (%)', '(c)'),
            (ax[1, 0], 'Threshold', 'Free Energy', '(b)'),
            (ax[1, 1], 'Sparsity', 'Free Energy', '(d)'),
        ):
            panel.set_xlabel(xlabel, fontsize=12)
            panel.set_ylabel(ylabel, fontsize=12)
            panel.grid(True, linestyle='--', alpha=0.5)
            panel.text(0.5, -0.3, tag, fontsize=12, ha='center', transform=panel.transAxes)
            panel.legend()

        # One figure per dataset: title it so the figures can be told apart.
        fig.suptitle(f'{model_type} on {dataset}', fontsize=12)
        plt.subplots_adjust(hspace=0.4, wspace=0.4)
        plt.show()
        plt.close(fig)



# Encoder-decoder

In [None]:
import warnings
warnings.filterwarnings('ignore')

# Vocabulary-size sweep: each min_freq value becomes one curve in every panel.
min_freqs = [2, 20, 80]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # 5 distinct colors
markers = ['o', 's', '^', 'D', '*']  # 5 different markers
# Legend label per min_freq. Defined here so this cell does not depend on another cell having run.
legends = {2: 'big', 20: 'medium', 80: 'small'}

# (lower, upper) pairs, one per entry of min_freqs. Not drawn by this cell.
# TODO(review): confirm the intended values and which axis (threshold or sparsity) they refer to.
shaded_regions = [
    (0.32, 0.5),   # min_freq=2
    (0.25, 0.4),   # min_freq=20
    (0.10, 0.19),  # min_freq=80
]

for dataset in ['reviews']:
    for prune_emb in [False, True]:
        for model_type in ['encoder_decoder']:
            config = {
                'batch_size': 32,
                'seed': 42,
                'num_epochs': 25,
                'learning_rate': 1e-4,
                'device': device,
                'dataset_name': dataset,
                'train_dir': 'Reviews.csv',
                'test_dir': 'Reviews.csv',
                'text_col': 'Text',
                'label_col': 'Score',
                'max_length': 256,
                'num_classes': 5,
                'model_type': model_type,
                'model_size': 'small',
                'optimizer_type': 'adam',
            }
            fig, ax = plt.subplots(2, 2, figsize=(9.2, 6.2), dpi=300)
            for min_freq, color, marker in zip(min_freqs, colors, markers):
                config['min_freq'] = min_freq
                for wd in [0]:
                    torch.manual_seed(config['seed'])
                    (train_dataloader, test_dataloader), model = setup_experiment(config)
                    history = train(
                        model=model,
                        train_dataloader=train_dataloader,
                        val_dataloader=test_dataloader,
                        dataset_name=config['dataset_name'],
                        model_type=config['model_type'],
                        model_size=config['model_size'],
                        num_epochs=config['num_epochs'],
                        optimizer_type=config['optimizer_type'],
                        lr=config['learning_rate'],
                        weight_decay=wd,
                        device=config['device'],
                        save_path='./results/models'
                    )

                    print('===================================')
                    print('===================================')
                    print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                    # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding 'cuda'.
                    results_df = conduct_experiment(
                        model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                        regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                        save_path='results/metrics', weight_decay=wd, min_freq=min_freq, plot=False)

                    x_tr = results_df['threshold']
                    y_ac = results_df['accuracy']
                    x_den = results_df['sparsity']
                    # Drop the first free-energy row (x-series are sliced the same way below) and
                    # min-max scale to [0, 1]; a flat curve maps to 0 instead of dividing by zero.
                    fe_raw = results_df['free_energy'][1:]
                    fe_min = fe_raw.min()
                    fe_span = fe_raw.max() - fe_min
                    y_fe = (fe_raw - fe_min) / fe_span if fe_span else fe_raw - fe_min
                    print(y_fe)

                    curve_label = f'vocab_size={legends[min_freq]}'
                    for panel, xs, ys in (
                        (ax[0, 0], x_tr, y_ac),       # accuracy vs threshold
                        (ax[0, 1], x_den, y_ac),      # accuracy vs sparsity
                        (ax[1, 0], x_tr[1:], y_fe),   # free energy vs threshold
                        (ax[1, 1], x_den[1:], y_fe),  # free energy vs sparsity
                    ):
                        panel.plot(xs, ys, color=color, linewidth=1.5, marker=marker,
                                   markersize=6, label=curve_label)

            # Panel letters run down each column: (a)/(b) on the left, (c)/(d) on the right.
            for panel, xlabel, ylabel, tag in (
                (ax[0, 0], 'Threshold', 'Accuracy (%)', '(a)'),
                (ax[0, 1], 'Sparsity', 'Accuracy (%)', '(c)'),
                (ax[1, 0], 'Threshold', 'Free Energy', '(b)'),
                (ax[1, 1], 'Sparsity', 'Free Energy', '(d)'),
            ):
                panel.set_xlabel(xlabel, fontsize=12)
                panel.set_ylabel(ylabel, fontsize=12)
                panel.grid(True, linestyle='--', alpha=0.5)
                panel.text(0.5, -0.3, tag, fontsize=12, ha='center', transform=panel.transAxes)
                panel.legend()

            # One figure per prune_emb setting: title it so the two figures can be told apart.
            fig.suptitle(f'{model_type} on {dataset} (prune_embedding={prune_emb})', fontsize=12)
            plt.subplots_adjust(hspace=0.4, wspace=0.4)
            plt.show()
            plt.close(fig)


In [None]:
import warnings
warnings.filterwarnings('ignore')

# Vocabulary-size sweep: each min_freq value becomes one curve in every panel.
min_freqs = [2, 20, 80]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # 5 distinct colors
markers = ['o', 's', '^', 'D', '*']  # 5 different markers
# Legend label per min_freq. Defined here so this cell does not depend on another cell having run.
legends = {2: 'big', 20: 'medium', 80: 'small'}

# (lower, upper) pairs, one per entry of min_freqs. Not drawn by this cell.
# TODO(review): confirm the intended values and which axis (threshold or sparsity) they refer to.
shaded_regions = [
    (0.32, 0.5),   # min_freq=2
    (0.25, 0.4),   # min_freq=20
    (0.10, 0.19),  # min_freq=80
]

for dataset in ['reviews']:
    for prune_emb in [False, True]:
        for model_type in ['encoder']:
            config = {
                'batch_size': 32,
                'seed': 42,
                'num_epochs': 25,
                'learning_rate': 1e-4,
                'device': device,
                'dataset_name': dataset,
                'train_dir': 'Reviews.csv',
                'test_dir': 'Reviews.csv',
                'text_col': 'Text',
                'label_col': 'Score',
                'max_length': 256,
                'num_classes': 5,
                'model_type': model_type,
                'model_size': 'small',
                'optimizer_type': 'adam',
            }
            fig, ax = plt.subplots(2, 2, figsize=(9.2, 6.2), dpi=300)
            for min_freq, color, marker in zip(min_freqs, colors, markers):
                config['min_freq'] = min_freq
                for wd in [0]:
                    torch.manual_seed(config['seed'])
                    (train_dataloader, test_dataloader), model = setup_experiment(config)
                    history = train(
                        model=model,
                        train_dataloader=train_dataloader,
                        val_dataloader=test_dataloader,
                        dataset_name=config['dataset_name'],
                        model_type=config['model_type'],
                        model_size=config['model_size'],
                        num_epochs=config['num_epochs'],
                        optimizer_type=config['optimizer_type'],
                        lr=config['learning_rate'],
                        weight_decay=wd,
                        device=config['device'],
                        save_path='./results/models'
                    )

                    print('===================================')
                    print('===================================')
                    print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                    # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding 'cuda'.
                    results_df = conduct_experiment(
                        model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                        regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                        save_path='results/metrics', weight_decay=wd, min_freq=min_freq, plot=False)

                    x_tr = results_df['threshold']
                    y_ac = results_df['accuracy']
                    x_den = results_df['sparsity']
                    # Drop the first free-energy row (x-series are sliced the same way below) and
                    # min-max scale to [0, 1]; a flat curve maps to 0 instead of dividing by zero.
                    fe_raw = results_df['free_energy'][1:]
                    fe_min = fe_raw.min()
                    fe_span = fe_raw.max() - fe_min
                    y_fe = (fe_raw - fe_min) / fe_span if fe_span else fe_raw - fe_min
                    print(y_fe)

                    curve_label = f'vocab_size={legends[min_freq]}'
                    for panel, xs, ys in (
                        (ax[0, 0], x_tr, y_ac),       # accuracy vs threshold
                        (ax[0, 1], x_den, y_ac),      # accuracy vs sparsity
                        (ax[1, 0], x_tr[1:], y_fe),   # free energy vs threshold
                        (ax[1, 1], x_den[1:], y_fe),  # free energy vs sparsity
                    ):
                        panel.plot(xs, ys, color=color, linewidth=1.5, marker=marker,
                                   markersize=6, label=curve_label)

            # Panel letters run down each column: (a)/(b) on the left, (c)/(d) on the right.
            for panel, xlabel, ylabel, tag in (
                (ax[0, 0], 'Threshold', 'Accuracy (%)', '(a)'),
                (ax[0, 1], 'Sparsity', 'Accuracy (%)', '(c)'),
                (ax[1, 0], 'Threshold', 'Free Energy', '(b)'),
                (ax[1, 1], 'Sparsity', 'Free Energy', '(d)'),
            ):
                panel.set_xlabel(xlabel, fontsize=12)
                panel.set_ylabel(ylabel, fontsize=12)
                panel.grid(True, linestyle='--', alpha=0.5)
                panel.text(0.5, -0.3, tag, fontsize=12, ha='center', transform=panel.transAxes)
                panel.legend()

            # One figure per prune_emb setting: title it so the two figures can be told apart.
            fig.suptitle(f'{model_type} on {dataset} (prune_embedding={prune_emb})', fontsize=12)
            plt.subplots_adjust(hspace=0.4, wspace=0.4)
            plt.show()
            plt.close(fig)


In [None]:
import warnings
warnings.filterwarnings('ignore')

# Vocabulary-size sweep: each min_freq value becomes one curve in every panel.
min_freqs = [2, 20, 80]
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']  # 5 distinct colors
markers = ['o', 's', '^', 'D', '*']  # 5 different markers

# (lower, upper) pairs, one per entry of min_freqs. Not drawn by this cell.
# TODO(review): confirm the intended values and which axis (threshold or sparsity) they refer to.
shaded_regions = [
    (0.32, 0.5),   # min_freq=2
    (0.25, 0.4),   # min_freq=20
    (0.10, 0.19),  # min_freq=80
]

for dataset in ['imdb']:
    for prune_emb in [False, True]:
        for model_type in ['encoder']:
            # NOTE(review): train_dir/test_dir point at Reviews.csv and num_classes is 5 for IMDB;
            # presumably setup_experiment resolves these per dataset — confirm.
            config = {
                'batch_size': 32,
                'seed': 42,
                'num_epochs': 25,
                'learning_rate': 1e-4,
                'device': device,
                'dataset_name': dataset,
                'train_dir': 'Reviews.csv',
                'test_dir': 'Reviews.csv',
                'text_col': 'Text',
                'label_col': 'Score',
                'max_length': 256,
                'num_classes': 5,
                'model_type': model_type,
                'model_size': 'small',
                'optimizer_type': 'adam',
            }
            fig, ax = plt.subplots(2, 2, figsize=(9.2, 6.2), dpi=300)
            for min_freq, color, marker in zip(min_freqs, colors, markers):
                config['min_freq'] = min_freq
                for wd in [0]:
                    torch.manual_seed(config['seed'])
                    (train_dataloader, test_dataloader), model = setup_experiment(config)
                    history = train(
                        model=model,
                        train_dataloader=train_dataloader,
                        val_dataloader=test_dataloader,
                        dataset_name=config['dataset_name'],
                        model_type=config['model_type'],
                        model_size=config['model_size'],
                        num_epochs=config['num_epochs'],
                        optimizer_type=config['optimizer_type'],
                        lr=config['learning_rate'],
                        weight_decay=wd,
                        device=config['device'],
                        save_path='./results/models'
                    )

                    print('===================================')
                    print('===================================')
                    print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                    # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding 'cuda'.
                    results_df = conduct_experiment(
                        model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                        regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                        save_path='results/metrics', weight_decay=wd, min_freq=min_freq, plot=False)

                    x_tr = results_df['threshold']
                    y_ac = results_df['accuracy']
                    x_den = results_df['sparsity']
                    # Drop the first free-energy row (x-series are sliced the same way below) and
                    # min-max scale to [0, 1]; a flat curve maps to 0 instead of dividing by zero.
                    fe_raw = results_df['free_energy'][1:]
                    fe_min = fe_raw.min()
                    fe_span = fe_raw.max() - fe_min
                    y_fe = (fe_raw - fe_min) / fe_span if fe_span else fe_raw - fe_min
                    print(y_fe)

                    curve_label = f'min_freq={min_freq}'
                    for panel, xs, ys in (
                        (ax[0, 0], x_tr, y_ac),       # accuracy vs threshold
                        (ax[0, 1], x_den, y_ac),      # accuracy vs sparsity
                        (ax[1, 0], x_tr[1:], y_fe),   # free energy vs threshold
                        (ax[1, 1], x_den[1:], y_fe),  # free energy vs sparsity
                    ):
                        panel.plot(xs, ys, color=color, linewidth=1.5, marker=marker,
                                   markersize=6, label=curve_label)

            for panel, xlabel, ylabel in (
                (ax[0, 0], 'Threshold', 'Accuracy (%)'),
                (ax[0, 1], 'Sparsity', 'Accuracy (%)'),
                (ax[1, 0], 'Threshold', 'Free Energy'),
                (ax[1, 1], 'Sparsity', 'Free Energy'),
            ):
                panel.set_xlabel(xlabel, fontsize=12)
                panel.set_ylabel(ylabel, fontsize=12)
                panel.grid(True, linestyle='--', alpha=0.5)
                panel.legend()

            # One figure per prune_emb setting: title it so the two figures can be told apart.
            fig.suptitle(f'{model_type} on {dataset} (prune_embedding={prune_emb})', fontsize=12)
            plt.subplots_adjust(hspace=0.4, wspace=0.4)
            plt.show()
            plt.close(fig)


In [None]:
import warnings
from torch import nn
warnings.filterwarnings('ignore')

for dataset in ['imdb', 'reviews']:
    for model_type in ['pretrained_transformer']:
        # NOTE(review): train_dir/test_dir point at Reviews.csv and num_classes is 5 even when
        # dataset == 'imdb'; presumably setup_experiment resolves these per dataset — confirm.
        config = {
            'batch_size': 32,
            'seed': 42,
            'num_epochs': 10,
            'learning_rate': 1e-4,
            'device': device,
            'dataset_name': dataset,
            'train_dir': 'Reviews.csv',
            'test_dir': 'Reviews.csv',
            'text_col': 'Text',
            'label_col': 'Score',
            'max_length': 256,
            'num_classes': 5,
            'model_type': model_type,
            'model_size': 'small',
            'optimizer_type': 'adam',
            'pretrained_model_name': "huawei-noah/TinyBERT_General_4L_312D"
        }
        for min_freq in [0]:
            config['min_freq'] = min_freq
            for wd in [0]:
                torch.manual_seed(config['seed'])
                (train_dataloader, test_dataloader), model = setup_experiment(config)
                # For this model type setup_experiment appears to return a sequence whose first
                # element is the network.
                model = model[0].to(config['device'])
                # Replace the head with one sized to config['num_classes']. A freshly built
                # nn.Linear lives on the CPU, so move it next to the backbone explicitly.
                model.classifier = nn.Linear(
                    model.bert.pooler.dense.out_features, config['num_classes']
                ).to(config['device'])
                history = train(
                    model=model,
                    train_dataloader=train_dataloader,
                    val_dataloader=test_dataloader,
                    dataset_name=config['dataset_name'],
                    model_type=config['model_type'],
                    model_size=config['model_size'],
                    num_epochs=config['num_epochs'],
                    optimizer_type=config['optimizer_type'],
                    lr=config['learning_rate'],
                    weight_decay=wd,
                    device=config['device'],
                    save_path='./results/models'
                )
                for prune_emb in [True]:
                    print('===================================')
                    print('===================================')
                    print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                    # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding 'cuda'.
                    conduct_experiment(
                        model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                        regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                        save_path='results/metrics', weight_decay=wd)

# CIFAR-10 reproduction

In [None]:
from torch import nn
from torch.nn import functional as F
# Star import kept on purpose: torch.load(..., weights_only=False) below unpickles a whole module
# object, and pickle must be able to resolve its class definitions. Presumably the classes were
# pickled from the notebook namespace, so they need to exist here — TODO confirm.
from cv_models import *  # importing all models

import warnings
warnings.filterwarnings('ignore')

CIFAR10_CHECKPOINT = './results/models/cifar10_mlp_small_best.pth'

for model_type in ['mlp']:
    config = {
        'batch_size': 32,
        'seed': 42,
        'num_epochs': 25,
        'learning_rate': 1e-4,
        'device': device,
        'dataset_name': 'cifar10',
        'max_length': 256,
        'num_classes': 10,  # CIFAR-10 has 10 classes
        'model_type': model_type,
        'model_size': 'small',
        'optimizer_type': 'adam',
    }
    for min_freq in [0]:
        config['min_freq'] = min_freq
        for wd in [0]:
            torch.manual_seed(config['seed'])
            # Only the dataloaders are kept from setup_experiment; the network is replaced by the
            # checkpoint loaded on the next line. Forward slashes keep the path portable and
            # avoid the invalid '\c' escape sequence.
            (train_dataloader, test_dataloader), model = setup_experiment(
                config, load_model=False, model_path=CIFAR10_CHECKPOINT)
            # weights_only=False unpickles arbitrary objects: only load checkpoints you trust.
            model = torch.load(CIFAR10_CHECKPOINT, weights_only=False)['net'].to(device)
            # Must be *called*: bare `model.eval` is a no-op and leaves the network in train mode.
            model.eval()
            for prune_emb in [False]:
                print('===================================')
                print('===================================')
                print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                # Pass the actually selected device (as 'cuda'/'cpu') instead of hard-coding 'cuda'.
                conduct_experiment(
                    model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                    regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                    save_path='results/metrics_2', weight_decay=wd)

# Pruning time measurement

In [None]:
import time
import warnings
from torch import nn
warnings.filterwarnings('ignore')

for model_type in ['pretrained_transformer']:
    config = {
        'batch_size': 32,
        'seed': 42,
        'num_epochs': 25,
        'learning_rate': 1e-4,
        'device': device,
        'dataset_name': 'reviews',
        'train_dir': 'Reviews.csv',
        'test_dir': 'Reviews.csv',
        'text_col': 'Text',
        'label_col': 'Score',
        'max_length': 256,
        'num_classes': 5,
        'model_type': model_type,
        'model_size': 'small',
        'optimizer_type': 'adam',
        'pretrained_model_name': "huawei-noah/TinyBERT_General_4L_312D"
    }
    for min_freq in [2]:
        config['min_freq'] = min_freq
        for wd in [0]:
            torch.manual_seed(config['seed'])
            (train_dataloader, test_dataloader), model = setup_experiment(config)
            if config['model_type'] == 'pretrained_transformer':
                # setup_experiment appears to return a sequence for this model type; the network
                # is its first element.
                model = model[0].to(config['device'])
                # Size the classification head to config['num_classes'] and keep it on the same
                # device as the backbone (a new nn.Linear is created on the CPU).
                model.classifier = nn.Linear(
                    model.classifier.in_features, config['num_classes']
                ).to(config['device'])
            history = train(
                model=model,
                train_dataloader=train_dataloader,
                val_dataloader=test_dataloader,
                dataset_name=config['dataset_name'],
                model_type=config['model_type'],
                model_size=config['model_size'],
                num_epochs=config['num_epochs'],
                optimizer_type=config['optimizer_type'],
                lr=config['learning_rate'],
                weight_decay=wd,
                device=config['device'],
                save_path='./results/models'
            )
            for prune_emb in [True]:
                print('===================================')
                print('===================================')
                print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                # Time each pruning run on its own with a monotonic clock.
                start = time.perf_counter()
                conduct_experiment(
                    model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                    regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                    save_path='results/metrics_time_measurments', weight_decay=wd,
                    accuracy_only=True, plot=False)
                elapsed = time.perf_counter() - start
                print(f"Time taken for {config['model_type']}: {elapsed} seconds")


In [None]:
import time
import warnings
from torch import nn
warnings.filterwarnings('ignore')

for model_type in ['bow_mlp', 'encoder_decoder', 'encoder', 'pretrained_transformer']:
    # NOTE(review): num_classes is 5 for IMDB; presumably setup_experiment/train handle this — confirm.
    config = {
        'batch_size': 32,
        'seed': 42,
        'num_epochs': 25,
        'learning_rate': 1e-4,
        'device': device,
        'dataset_name': 'imdb',
        'train_dir': 'aclImdb/aclImdb/train',
        'test_dir': 'aclImdb/aclImdb/test',
        'text_col': 'Text',
        'label_col': 'Score',
        'max_length': 256,
        'num_classes': 5,
        'model_type': model_type,
        'model_size': 'small',
        'optimizer_type': 'adam',
        'pretrained_model_name': "huawei-noah/TinyBERT_General_4L_312D"
    }
    for min_freq in [2]:
        config['min_freq'] = min_freq
        for wd in [0]:
            torch.manual_seed(config['seed'])
            (train_dataloader, test_dataloader), model = setup_experiment(config)
            if config['model_type'] == 'pretrained_transformer':
                # setup_experiment appears to return a sequence for this model type; the network
                # is its first element.
                model = model[0].to(config['device'])
                # New head sized to config['num_classes']; move it to the backbone's device
                # because a freshly built nn.Linear is created on the CPU.
                model.classifier = nn.Linear(
                    model.classifier.in_features, config['num_classes']
                ).to(config['device'])
            history = train(
                model=model,
                train_dataloader=train_dataloader,
                val_dataloader=test_dataloader,
                dataset_name=config['dataset_name'],
                model_type=config['model_type'],
                model_size=config['model_size'],
                num_epochs=config['num_epochs'],
                optimizer_type=config['optimizer_type'],
                lr=config['learning_rate'],
                weight_decay=wd,
                device=config['device'],
                save_path='./results/models'
            )
            for prune_emb in [True]:
                print('===================================')
                print('===================================')
                print(f"MODEL: {config['model_type']}; DICT: {config['min_freq']}; PRUNE_EMB: {prune_emb}; WD: {wd}")
                # Time each pruning run on its own with a monotonic clock.
                start = time.perf_counter()
                conduct_experiment(
                    model, test_dataloader, config['dataset_name'], config['model_type'], config['model_size'],
                    regime='linear_only', prune_embedding=prune_emb, device=str(config['device']),
                    save_path='results/metrics_time_measurments_energy', weight_decay=wd,
                    energy_only=True, plot=False)
                elapsed = time.perf_counter() - start
                print(f"Time taken for {config['model_type']}: {elapsed} seconds")
