# GAN for album cover generation

### Summary

This project creates a Generative Adversarial Network (GAN) for album cover generation based on the Spotify dataset.
It is based on [this paper](https://ryanmcconville.com/publications/AlbumCoverGenerationFromGenreTags.pdf) and the ultimate objective is to generate album images based on music genres.

The process has several intermediate steps:

1. Generate a dataset
2. Create a genre classifier
3. Create an album image GAN
4. Create a GAN that generates albums based on genre

In [43]:
!pip3 install spotipy



In [57]:
import spotipy
import time
import random
from spotipy.oauth2 import SpotifyClientCredentials

## Loading spotify credentials

You need to store your Spotify credentials in a JSON file.

The file should be named **spotify_credentials.json** and look like this:

```json
{
    "client_id": "your_client_id",
    "client_secret": "your_client_secret"
}
```

More details on how to use JSON files as configuration files can be found [here](https://martin-thoma.com/configuration-files-in-python/#:~:text=configuration%20handling%3A%20cfg_load-,Python%20Configuration%20File,to%20avoid%20uploading%20it%20accidentally).

In [91]:
import json

with open("spotify_credentials.json") as json_data_file:
    data = json.load(json_data_file)

client_id = data["client_id"]
client_secret = data["client_secret"]
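
A slightly more defensive variant of the cell above is sketched below. It assumes the same file name and keys; `load_credentials` is a hypothetical helper, not part of the original notebook, and only adds a readable error when a key is missing or empty.

```python
import json


def load_credentials(path="spotify_credentials.json"):
    """Read the Spotify client id and secret from a JSON file.

    Hypothetical helper, not part of the original notebook.
    """
    with open(path) as json_data_file:
        data = json.load(json_data_file)

    # Fail early with a readable message if a key is absent or empty.
    missing = [key for key in ("client_id", "client_secret") if not data.get(key)]
    if missing:
        raise KeyError("Missing keys in {}: {}".format(path, ", ".join(missing)))

    return data["client_id"], data["client_secret"]
```

It could be used as `client_id, client_secret = load_credentials()` in place of the two assignments above.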

In [92]:
client_credentials_manager = SpotifyClientCredentials(client_id, client_secret)
sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)

## Creating a list of images

Spotify has no API endpoint that returns all album images for a `genre`.\
We work around that limitation by assuming that some Spotify categories behave as genres.

From there, the steps are:
- Get a list of categories
- Get all the playlists for a category
- Get all songs for a playlist
- Get an album image for every song
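
The last step depends on the nested shape of each playlist item. Below is a minimal sketch of that lookup, assuming items shaped the way this notebook reads them (`item['track']['album']['images'][0]['url']`). The function name `image_urls_from_tracks` and the sample data are illustrative assumptions, not real API output.

```python
def image_urls_from_tracks(track_items):
    """Collect the first album image URL of every playlist track item."""
    urls = set()
    for item in track_items:
        track = item.get("track")
        if not track:
            continue  # the playlist entry has no track data
        album = track.get("album")
        if not album:
            continue
        images = album.get("images")
        if not images:
            continue  # the album has no cover images
        url = images[0].get("url")
        if url:
            urls.add(url)  # the set removes duplicate covers
    return urls


# Hand-written sample data, not real API output.
sample_items = [
    {"track": {"album": {"images": [{"url": "https://example.com/a"}]}}},
    {"track": {"album": {"images": [{"url": "https://example.com/a"}]}}},
    {"track": None},
    {"track": {"album": {"images": []}}},
]

print(image_urls_from_tracks(sample_items))
```

Because several tracks often share one album, collecting URLs in a set keeps each cover only once.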

In [93]:
categories = [
    "rock",
    "pop",
    "classical",
    "hiphop",
    "country",
    "latin",
    "edm_dance",
    "jazz"
]

In [94]:
def get_playlists(
    category,
    country = "US"
):
    playlists = list()
    
    offset = 0
    while True:  # fetch playlists page by page until there is no next page
        playlistResponse = sp.category_playlists(category, limit = 20, offset = offset, country = country)['playlists']
        playlists.extend(playlistResponse['items'])

        if playlistResponse['next'] is None:
            break
        else:
            offset += 20
    
    return playlists

In [95]:
def get_tracks(playlist):
    tracks = list()
    
    offset = 0
    while True:  # fetch tracks page by page until there is no next page
        tracksResponse = sp.playlist_tracks(playlist['id'], limit=100, offset = offset)
        tracks.extend(tracksResponse['items'])

        if tracksResponse['next'] is None:
            break
        else:
            offset += 100
    
    return tracks
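
Both functions above follow the same offset-based pagination loop. As a sketch only, that loop could be factored into one helper. `fetch_all` and `fake_page` are hypothetical names, and the helper assumes each page is a dict with `items` and `next` keys, as in the responses used above.

```python
def fetch_all(fetch_page, page_size):
    """Collect the items of every page until the response has no 'next' link.

    fetch_page is a hypothetical callable taking (limit, offset) and returning
    a dict with 'items' and 'next' keys.
    """
    items = []
    offset = 0
    while True:
        page = fetch_page(limit=page_size, offset=offset)
        items.extend(page["items"])
        if page["next"] is None:
            break
        offset += page_size
    return items


# Stand-in for an API call: serves 5 fake items, page by page.
def fake_page(limit, offset):
    data = list(range(5))
    chunk = data[offset:offset + limit]
    has_more = offset + limit < len(data)
    return {"items": chunk, "next": "more" if has_more else None}


print(fetch_all(fake_page, page_size=2))  # prints [0, 1, 2, 3, 4]
```

With the real client, the callable could be something like `lambda limit, offset: sp.playlist_tracks(playlist['id'], limit=limit, offset=offset)`. For categories, the callable would also need to index `['playlists']` on the response.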

## Downloading all the images

Arguments of the function `download_url`:

<ul>
<li>url (str): URL to download file from</li>
<li>root (str): Directory to place downloaded file in</li>
<li>filename (str, optional): Name to save the file under. If None, use the basename of the URL</li>
</ul>
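
To make the `filename` default concrete, the sketch below mirrors only the path logic of `download_url`, without downloading anything. `target_path` is a hypothetical helper and the URL is a made-up example. Note that `.png` is always appended, including when `filename` is given.

```python
import os


def target_path(url, root, filename=None):
    """Return the path download_url would write to for these arguments."""
    root = os.path.expanduser(root)
    if not filename:
        # Fall back to the last path segment of the URL.
        filename = os.path.basename(url)
    return os.path.join(root, filename + ".png")


# Made-up URL, for illustration only.
print(target_path("https://example.com/covers/abc123", "/tmp/images"))
print(target_path("https://example.com/covers/abc123", "/tmp/images", "cover"))
```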


In [96]:
import os
import urllib.error    # a bare "import urllib" does not reliably load the submodules
import urllib.request

def download_url(url, root, filename=None):
    root = os.path.expanduser(root)
    if not filename:
        filename = os.path.basename(url)
    # ".png" is always appended, whatever format the server returns
    fpath = os.path.join(root, filename + ".png")

    os.makedirs(root, exist_ok=True)

    try:
        print('Downloading ' + url + ' to ' + fpath)
        urllib.request.urlretrieve(url, fpath)
    except (urllib.error.URLError, IOError):
        if url.startswith('https:'):
            # Retry once over plain HTTP
            url = url.replace('https:', 'http:')
            print('Failed download. Trying https -> http instead.'
                    ' Downloading ' + url + ' to ' + fpath)
            urllib.request.urlretrieve(url, fpath)

In [99]:
def get_album_images(category):
    print("Downloading playlists")
    playlists = get_playlists(category)
    print("Downloaded {} playlists".format(len(playlists)))

    print("Downloading tracks")
    tracks = list()
    for playlist in playlists:
        tracks.extend(get_tracks(playlist))
    print("Downloaded {} tracks".format(len(tracks)))

    allImages = set()
    for track in tracks:
        trackData = track['track']
        if trackData is None: continue

        album = trackData['album']
        if album is None: continue

        trackImages = album['images']
        if trackImages is None or len(trackImages) == 0: continue

        # Use the first image listed for the album
        imageUrl = trackImages[0]['url']

        # Download each distinct URL once, into a folder named after the category
        if imageUrl is not None and imageUrl not in allImages:
            allImages.add(imageUrl)
            download_url(imageUrl, "~/images/" + category)

    return allImages

In [None]:
#len(get_album_images("latin"))
len(get_album_images("country"))

Downloading playlists
Downloaded 77 playlists
Downloading tracks
