In [203]:
%pip install Levenshtein


Note: you may need to restart the kernel to use updated packages.


In [204]:
# IMPORTS

import psutil
import json
import os
import random
import torch

from datetime import datetime
from Levenshtein import distance
from jinja2 import Template

In [205]:
# GET COMPUTER SPECS

platform_info = {}

platform_info['physical_cpu_cores'] = psutil.cpu_count(logical=False)
platform_info['total_cpu_cores'] = psutil.cpu_count(logical=True)

def get_available_device():
    """Helper method to find best possible hardware to run
    Returns:
        torch.device used to run experiments.
        str representation of backend.
    """
    # Check if CUDA is available
    if torch.cuda.is_available():
        return torch.device("cuda"), "cuda"

    # Check if ROCm is available
    if torch.version.hip is not None and torch.backends.mps.is_available():
        return torch.device("rocm"), "rocm"

    # Check if MPS (Apple Silicon) is available
    if torch.backends.mps.is_available():
        return torch.device('cpu'), "mps"

    # Fall back to CPU
    return torch.device("cpu"), "cpu"

# Check device info
device, backend = get_available_device()

# Check for GPU-specific details if CUDA or ROCm is available
if device.type == "cuda":
    cuda_device_count = torch.cuda.device_count()
    cuda_device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.cuda
elif device.type == "rocm":
    cuda_device_count = torch.cuda.device_count()
    cuda_device_name = torch.cuda.get_device_name(0)
    cuda_version = torch.version.hip
else:
    cuda_device_count = 0
    cuda_device_name = "N/A"
    cuda_version = "N/A"

platform_info['device'] = device.type
platform_info['backend'] = backend
platform_info['cuda_device_count'] = cuda_device_count
platform_info['cuda_device_name'] = cuda_device_name
platform_info['cuda_version'] = cuda_version

# print(json.dumps(platform_info, indent=4))

In [206]:
# POSITIVE PROMPT TEMPLATES

SIMILARITY_DISTANCE = 3
def are_tracks_similar(tracks):
    for i in range(len(tracks)):
        for j in range(i + 1, len(tracks)):
            if distance(tracks[i]['name'], tracks[j]['name']) < SIMILARITY_DISTANCE:
                return True
    return False

def similarity_groups(tracks):
    found_group = [False] * len(tracks)
    groups = []

    for i in range(len(tracks)):
        if found_group[i]:
            continue

        found_group[i] = True
        groups.append([tracks[i]])

        for j in range(i + 1, len(tracks)):
            if found_group[j]:
                continue
            if distance(tracks[i]['name'], tracks[j]['name']) < SIMILARITY_DISTANCE:
                found_group[j] = True

    return groups

pos_prompt_templates = {}

pos_prompt_templates['1-long'] =  Template("""\
Album cover for this album:
Album name : {{ album.name }}
Artist{% if album.artists|length > 1 %}s{% endif %} : {{ album.artists | join(', ') }}
Release Date : {{ album.date }}
Label : {{ album.label }}
Tracks:
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}
""")

pos_prompt_templates['2-only-tracks'] =  Template("""\
Album cover for these tracks: 
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}
""")

pos_prompt_templates['3-only-title'] =  Template("""\
Album cover for "{{ album.name }}"
""")

pos_prompt_templates['4-long-with-track-similarity'] =  Template("""\
Album cover for this album:
Album name : {{ album.name }}
Artist{% if album.artists|length > 1 %}s{% endif %} : {{ album.artists | join(', ') }}
Release Date : {{ album.date }}
Label : {{ album.label }}

{% if are_tracks_similar(album.tracks) %} Track format : {% for track in similarity_groups(album.tracks) %}- {{ track.name }}\n{% endfor %}
{% else %} Tracks:
{% for track in album.tracks %}- {{ track.name }}\n{% endfor %}{% endif %}
""")


In [207]:
# NEGATIVE PROMPT TEMPLATES

neg_prompts = {}
neg_prompts['1-no-text'] = "text"

In [208]:
# GET ALBUM DATA

file_id = "" # if need for a specific album, put the file name here

if file_id == "":
    album_files = os.listdir('input/albums')
    random_album_file = random.choice(album_files)
else:
    random_album_file = f'{file_id}.json'

with open(f'input/albums/{random_album_file}', 'r') as file:
    album_data = json.load(file)

In [209]:
# OTHER PARAMETERS

INFERENCE_STEPS = [20, 100]

In [None]:
# CREATE RUN PARAMETERS

run_info = {}
run_info['computer_specs'] = platform_info
run_info['album_id'] = album_data['id']

is_similar = are_tracks_similar(album_data['tracks'])

torch.cuda.empty_cache() # not sure if this is necessary

for pos_key, template in pos_prompt_templates.items():

    # Skip prompt if the tracks are not similar
    if pos_key == '4-long-with-track-similarity' and not is_similar:
        continue

    curr_run_info = run_info.copy()
    curr_run_info['pos_prompt'] = pos_key

    for neg_key, neg_prompt in neg_prompts.items():
        curr_run_info['neg_prompt'] = neg_key

        for step in INFERENCE_STEPS:
            curr_run_info['inference_steps'] = step

    print(json.dumps(curr_run_info, indent=4))

    del curr_run_info

{
    "computer_specs": {
        "physical_cpu_cores": 4,
        "total_cpu_cores": 8,
        "device": "cpu",
        "backend": "cpu",
        "cuda_device_count": 0,
        "cuda_device_name": "N/A",
        "cuda_version": "N/A"
    },
    "album_id": "0OYCfnuIteT2ECqMW8XtwY",
    "pos_prompt": "1-long",
    "neg_prompt": "1-no-text",
    "inference_steps": 100
}
{
    "computer_specs": {
        "physical_cpu_cores": 4,
        "total_cpu_cores": 8,
        "device": "cpu",
        "backend": "cpu",
        "cuda_device_count": 0,
        "cuda_device_name": "N/A",
        "cuda_version": "N/A"
    },
    "album_id": "0OYCfnuIteT2ECqMW8XtwY",
    "pos_prompt": "2-only-tracks",
    "neg_prompt": "1-no-text",
    "inference_steps": 100
}
{
    "computer_specs": {
        "physical_cpu_cores": 4,
        "total_cpu_cores": 8,
        "device": "cpu",
        "backend": "cpu",
        "cuda_device_count": 0,
        "cuda_device_name": "N/A",
        "cuda_version": "N/A"
    },
 