1. Compression Rate & Animation

In [5]:
# An ASCII character takes 1 byte in UTF-8 encoding.
# A token contains roughly 4 characters,
# so one token takes about 4 bytes of raw text.
# Llama3-8B is trained on 15T tokens, i.e. 60 trillion bytes (60 TB, about 54.6 TiB).

def calculate_compression_rate(
    num_tokens: float = 15e12,
    num_params: float = 8e9,
    model_name: str = "Llama3-8B",
    bytes_per_token: float = 4,
    bytes_per_param: float = 4,
):
    """Estimate how strongly an LLM "compresses" its training data.

    The compression rate is the size of the training corpus divided by the
    size of the model weights, both measured in bytes. A short summary is
    printed along the way.

    Args:
        num_tokens: Number of training tokens (default 15e12, i.e. 15T).
        num_params: Number of model parameters (default 8e9, i.e. 8B).
        model_name: Name used in the printed summary.
        bytes_per_token: Average bytes of raw text per token. The default of 4
            assumes ~4 characters per token at 1 byte per character.
        bytes_per_param: Bytes per stored parameter. The default of 4 assumes
            32-bit floats; use 2 for fp16/bf16 weights.

    Returns:
        A tuple ``(compression_rate, total_bytes, model_size_bytes)``.

    Raises:
        ValueError: If ``num_params`` or ``bytes_per_param`` is not positive,
            because the model size would then be zero or negative.
    """
    if num_params <= 0 or bytes_per_param <= 0:
        raise ValueError("num_params and bytes_per_param must be positive")

    # Size of the training corpus. The "TB" printed below uses binary units
    # (1 TB here = 1024**4 bytes, i.e. a TiB).
    total_bytes = num_tokens * bytes_per_token
    total_tb = total_bytes / (1024**4)

    print(f"{model_name} is trained with {num_tokens:.2e} tokens, which is {total_bytes:.2e} bytes, or {total_tb:.2f} TB")

    # Size of the model weights. "GB" below is likewise binary (1024**3 bytes).
    model_size_bytes = num_params * bytes_per_param
    model_size_gb = model_size_bytes / (1024**3)

    print(f"The size of the {model_name} model is approximately {model_size_gb:.2f} GB")

    # How many bytes of training text each byte of weights has to "store".
    compression_rate = total_bytes / model_size_bytes

    print(f"The compression rate is approximately {compression_rate:.2f}")

    return compression_rate, total_bytes, model_size_bytes


# Llama3-8B: 15T training tokens, 8B parameters (the function's defaults, spelled out).
compression_rate, data_size, model_size = calculate_compression_rate(
    num_tokens=15e12,
    num_params=8e9,
    model_name="Llama3-8B",
)


Llama3-8B is trained with 1.50e+13 tokens, which is 6.00e+13 bytes, or 54.57 TB
The size of the Llama3-8B model is approximately 29.80 GB
The compression rate is approximately 1875.00


In [14]:
import matplotlib.animation as animation
from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt
import numpy as np

def animate_shrinking_circle(compression_rate: float = 1875, n_frames: int = 100):
    """Animate a circle whose area shrinks by ``compression_rate`` and save it.

    The circle starts with an area ``compression_rate`` times larger than the
    final circle (radius 1), which visualises how much training data is
    squeezed into the model weights. While shrinking, the fill fades from
    white to black.

    Writes three files to the current working directory:
    ``start_frame.png``, ``shrinking_circle.gif`` and ``end_frame.png``.

    Args:
        compression_rate: Ratio of start area to end area (default 1875, the
            Llama3-8B estimate computed above).
        n_frames: Number of animation frames; must be at least 2.

    Raises:
        ValueError: If ``compression_rate`` is not positive or ``n_frames < 2``.
    """
    if compression_rate <= 0:
        raise ValueError("compression_rate must be positive")
    if n_frames < 2:
        raise ValueError("n_frames must be at least 2")

    # Area scales with radius**2, so the radius ratio is sqrt(area ratio).
    # Take the square root exactly once; r_start / r_end must equal it.
    radius_ratio = np.sqrt(compression_rate)
    r_end = 1  # radius of the final (small) circle
    r_start = r_end * radius_ratio

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.set_xlim(-r_start, r_start)
    ax.set_ylim(-r_start, r_start)
    ax.set_aspect('equal')
    ax.axis('off')

    fig.patch.set_facecolor('white')

    circle = plt.Circle((0, 0), r_start, facecolor='white', edgecolor='black')
    ax.add_patch(circle)

    # Save start frame
    fig.savefig('start_frame.png')

    last_frame = n_frames - 1

    def update(frame):
        # t goes linearly from 0 (first frame) to 1 (last frame).
        t = frame / last_frame
        circle.set_radius(r_start - (r_start - r_end) * t)

        # Fade the fill from white (1.0) to black (0.0) as the circle shrinks.
        grey = 1 - t
        circle.set_facecolor((grey, grey, grey))

        return circle,

    anim = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    # Save the animation as a gif (interval=50 ms matches fps=20).
    anim.save('shrinking_circle.gif', writer='pillow', fps=20)

    # Render the final state explicitly, then save it as the end frame.
    update(last_frame)
    fig.savefig('end_frame.png')

    plt.close(fig)  # Close the figure to prevent inline display


# Create and save the animation and the start/end frames.
animate_shrinking_circle()


2. What is the optimal compression under the current loss function? Assuming a Gaussian data distribution, the KL loss effectively tells the model to "ignore" long-tail data.

In [None]:
# Assume the ground-truth (GT) data is Gaussian distributed, and the data model is also a Gaussian with learnable parameters (mean and std).
# Goal: show that "accuracy" on the long tail is largely ignored under the KL-divergence loss,
# i.e. accuracy is encouraged mostly near the center of the distribution (close to the mean).
# Animate how the loss changes while the GT stays fixed and the data model varies.


In [104]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from scipy.stats import norm
from scipy.special import kl_div
from matplotlib.colors import to_rgba
import matplotlib.gridspec as gridspec
import matplotlib.colors as mcolors


def gaussian(x, mu, sigma):
    """Normal probability density with mean ``mu`` and std ``sigma``, evaluated at ``x``.

    Operates element-wise, so ``x`` may be a scalar or a NumPy array.
    """
    squared_dev = (x - mu) ** 2
    two_var = 2 * sigma ** 2
    normaliser = sigma * np.sqrt(2 * np.pi)
    return np.exp(-squared_dev / two_var) / normaliser

def kl_divergence(p, q):
    """Discrete KL divergence ``D_KL(p || q) = sum_i p_i * log(p_i / q_i)``.

    Terms with ``p_i == 0`` contribute 0 (the convention ``0 * log 0 = 0``).
    Only the entries where ``p`` is non-zero are evaluated, so those terms no
    longer trigger divide-by-zero / invalid-value RuntimeWarnings. A term with
    ``p_i > 0`` and ``q_i == 0`` yields ``inf``, as in the definition.

    Args:
        p, q: Array-likes of probabilities; they are broadcast together.

    Returns:
        The summed divergence (a NumPy float; 0.0 for empty input).
    """
    p_arr, q_arr = np.broadcast_arrays(np.asarray(p, dtype=float), np.asarray(q, dtype=float))
    support = p_arr != 0
    p_support = p_arr[support]
    return np.sum(p_support * np.log(p_support / q_arr[support]))


# Track the frame with the lowest total KL found by the search loop.
# (The per-frame history lists for the 3D plot are initialised right before
# the animation's update function, so they are not duplicated here.)
min_kl_frame = 0
min_kl_value = float('inf')

def calculate_kl(frame, center_halfwidth=2.0):
    """Compare the fixed ground-truth Gaussian with the data model at ``frame``.

    The data model is a Gaussian whose parameters follow a schedule of the
    frame index: ``mu = 0.5*cos(frame/10)``, ``sigma = 1 + 0.5*sin(frame/10)``.

    The divergence is computed point-wise on the global grid ``x`` against the
    global ``gt_dist`` with ``scipy.special.kl_div`` (``p*log(p/q) - p + q``)
    and summed over grid points. NOTE: the sum is not multiplied by the grid
    spacing, so the values are proportional to (not equal to) the continuous
    KL; the colouring thresholds used in the animation (25, 10, 100) apply to
    this unscaled sum.

    Args:
        frame: Animation frame index.
        center_halfwidth: Grid points with ``|x| < center_halfwidth`` count as
            the "center"; all others are "tail" (default 2.0).

    Returns:
        ``(total_kl, center_kl, tail_kl, relative_tail_kl, model_dist)`` where
        ``relative_tail_kl = tail_kl / total_kl`` (0.0 when ``total_kl`` is 0).
    """
    model_mu = 0.5 * np.cos(frame / 10)
    model_sigma = 1 + 0.5 * np.sin(frame / 10)
    model_dist = gaussian(x, model_mu, model_sigma)

    kl = kl_div(gt_dist, model_dist)
    center_range = np.abs(x) < center_halfwidth
    center_kl = np.sum(kl[center_range])
    tail_kl = np.sum(kl[~center_range])
    total_kl = center_kl + tail_kl
    # Guard against 0/0 if the model ever matches the ground truth exactly.
    relative_tail_kl = tail_kl / total_kl if total_kl > 0 else 0.0
    return total_kl, center_kl, tail_kl, relative_tail_kl, model_dist

# Ground Truth (GT) distribution parameters. The grid `x` and `gt_dist` must
# exist before the search loop, because calculate_kl reads both as globals
# (otherwise the cell fails with NameError on a fresh kernel).
gt_mu, gt_sigma = 0, 1

# X-axis range
x = np.linspace(-4, 4, 1000)

# GT distribution
gt_dist = gaussian(x, gt_mu, gt_sigma)

# Find the frame with the globally minimal total KL. This searches frames
# 0-99, whereas the animation below renders fewer frames.
for frame in range(100):
    total_kl = calculate_kl(frame)[0]
    if total_kl < min_kl_value:
        min_kl_value = total_kl
        min_kl_frame = frame

print(f"Global minimum KL divergence found at frame {min_kl_frame}")

# Set up the plot
fig = plt.figure(figsize=(16, 10))
gs = gridspec.GridSpec(1, 12, figure=fig)  # 1 row, 12 columns

# ax1 spans all 12 columns; the 3D axes ax2 occupies columns 8-10 and so sits
# on top of the right-hand part of ax1, inset-style.
# NOTE(review): confirm this overlap is intended.
ax1 = fig.add_subplot(gs[0, :12])
ax2 = fig.add_subplot(gs[0, 8:11], projection='3d')

fig.subplots_adjust(wspace=0.9)  # Adjust space between subplots

# Static ground-truth curve, plus an empty line that update() fills with the
# current data model.
ax1.plot(x, gt_dist, label='Ground Truth', color='black', linewidth=2)
(model_plot,) = ax1.plot([], [], label='Data Model', color='red', linewidth=2)
ax1.set_ylim(0, 0.5)
ax1.set_title('Data Model vs. Distribution', fontsize=16)
ax1.legend(fontsize=12)
ax1.set_xlabel('x', fontsize=12)
ax1.set_ylabel('Probability Density', fontsize=12)

# Axis labels for the 3D panel.
ax2.set_xlabel('Tail KL')
ax2.set_ylabel('Center KL')
ax2.set_zlabel('KL')

# Three stacked text rows in axes coordinates at the top-left of ax1; the
# last row (total KL) starts out green.
center_text, tail_text, kl_text = [
    ax1.text(0.05, y_pos, '', transform=ax1.transAxes, fontsize=14,
             verticalalignment='top', **extra_style)
    for y_pos, extra_style in ((0.95, {}), (0.85, {}), (0.75, {'color': "green"}))
]

# Per-frame history for the 3D scatter; update() appends one entry per call.
relative_tail_kls, center_kls, total_kls = [], [], []


def _kl_color(kl):
    """Map a total-KL value to RGBA: green at <= 10, red at >= 100, linear blend between."""
    if kl <= 10:
        return to_rgba('green', alpha=1.0)
    if kl >= 100:
        return to_rgba('red', alpha=1.0)
    t = (kl - 10) / 90
    green = np.array(mcolors.to_rgb('green'))
    red = np.array(mcolors.to_rgb('red'))
    rgb = green * (1 - t) + red * t
    return tuple(rgb.tolist() + [1.0])


# Animation update function
def update(frame):
    """Redraw both panels for animation frame ``frame``.

    Updates the data-model curve and the three text annotations in ``ax1``,
    re-creates the shaded center/tail areas, appends this frame's KL values
    to the history lists, and redraws the 3D scatter in ``ax2``
    (x: relative tail KL, y: center KL, z: total KL).

    Returns the changed ``ax1`` artists (only used by FuncAnimation when blitting).
    """
    total_kl, center_kl, tail_kl, relative_tail_kl, model_dist = calculate_kl(frame)

    # Restart the history whenever the animation starts over at frame 0, so
    # repeated passes (re-saving / re-displaying) do not accumulate points.
    # The lists are mutated in place, so no `global` statement is needed.
    if frame == 0:
        relative_tail_kls.clear()
        center_kls.clear()
        total_kls.clear()

    # Update distribution plot
    model_plot.set_data(x, model_dist)

    # Shares of the total KL from the center and the tails (guard 0/0).
    if total_kl > 0:
        center_share = center_kl / total_kl
        tail_share = tail_kl / total_kl
    else:
        center_share = tail_share = 0.0
    center_dominates = center_kl > tail_kl

    color_center = "red" if center_dominates else "green"
    color_tail = "green" if center_dominates else "red"
    low_total = total_kl < 25  # threshold on the unscaled KL sum
    color_kl = "green" if low_total else "red"
    kl_str = "Low Total KL" if low_total else "High Total KL"

    center_text.set_text(f'Center relative KL: {center_share*100:.1f}%')
    center_text.set_color(color_center)
    tail_text.set_text(f'Tail relative KL: {tail_share*100:.1f}%')
    tail_text.set_color(color_tail)
    kl_text.set_text(f'{kl_str}: {total_kl:.4f}')
    kl_text.set_color(color_kl)

    # Remove the previous frame's fill_between areas. Iterate over a copy:
    # removing artists while iterating the live ax1.collections skips every
    # other item, so stale fills would pile up frame after frame.
    for coll in list(ax1.collections):
        coll.remove()

    # Highlight center and tail areas
    center_mask = (x >= -2) & (x <= 2)
    ax1.fill_between(x[center_mask], 0, gt_dist[center_mask], alpha=0.2, color='green', label='Center')

    left_tail_mask = (x >= -4) & (x <= -2)
    right_tail_mask = (x >= 2) & (x <= 4)
    ax1.fill_between(x[left_tail_mask], 0, gt_dist[left_tail_mask], alpha=0.2, color='orange', label="Tail")
    ax1.fill_between(x[right_tail_mask], 0, gt_dist[right_tail_mask], alpha=0.2, color='orange')

    if frame == 0:
        ax1.legend(fontsize=12)

    # Update 3D plot data
    relative_tail_kls.append(relative_tail_kl)
    center_kls.append(center_kl)
    total_kls.append(total_kl)

    ax2.clear()
    ax2.scatter(relative_tail_kls, center_kls, total_kls,
                c=[_kl_color(kl) for kl in total_kls], s=50)

    ax2.set_xlim(0, max(relative_tail_kls))
    ax2.set_ylim(0, max(center_kls))
    ax2.set_zlim(0, max(total_kls))

    # The x data is the *relative* tail KL (tail_kl / total_kl), not tail KL.
    ax2.set_xlabel('Relative Tail KL')
    ax2.set_ylabel('Center KL')
    ax2.set_zlabel('KL')

    ax2.set_title(f'Frame {frame}')

    return model_plot, center_text, tail_text, kl_text


# Number of rendered frames (frame indices 0 .. N_FRAMES - 1).
N_FRAMES = 89

# Create animation. blit=False, so the whole figure (including the cleared
# and redrawn 3D axes) is re-rendered every frame.
anim = FuncAnimation(fig, update, frames=N_FRAMES, interval=100, blit=False)

# Save animation (interval=100 ms matches fps=10).
anim.save('kl_divergence_animation.gif', writer='pillow', fps=10)

# Redraw the last frame and save it as a still image.
update(N_FRAMES - 1)
fig.savefig('kl_divergence_last_frame.png')

plt.close(fig)  # Close the figure to prevent inline display

print("Animation saved as 'kl_divergence_animation.gif'")

Global minimum KL divergence found at frame 27
Animation saved as 'kl_divergence_animation.gif'
