# New Year's Resolutions

And so, here we are again. The calendar flips, and suddenly it’s January—time for that strange, almost ritualistic exercise in self-reinvention we call New Year’s resolutions. You know the drill: the vaguely nauseating societal expectation that you will, at this arbitrary cosmic checkpoint, decide to overhaul your life. Or at least pretend to, so the well-meaning but oppressively curious questioners (“So, any resolutions this year?”) can have their moment of mild but undeniable judgment while you scramble to articulate something—anything—that sounds simultaneously profound and achievable, despite knowing deep down that life is already…well, fine. Not perfect, but fine.

But what if, instead of passively dreading the resolution inquisition, you armed yourself? Not with actual resolutions, per se—because let’s be honest, who has the energy for actual follow-through?—but with a set of wry, irreverent, possibly brilliant ideas drawn from the great chaotic hive mind of the internet. Specifically, from nearly 5,000 New Year’s resolution tweets, harvested and distilled for your convenience. Not so much resolutions as ammunition. A way to deflect, charm, and outwit your interrogators, leaving them dazzled and perhaps even envious of your meta-level resolution game.

Ready? Let’s dive in.

In [43]:
import pandas as pd

prepped_data_url = 'https://www.dropbox.com/scl/fi/lw47ojiic0mzz9lp1xvl9/new_years_resolutions_prepped.parquet?rlkey=mc362g8s5x6bc3zqfhslf28oo&dl=1'

df = pd.read_parquet(prepped_data_url)

# have a little peek:
print(f"{df.shape}")
df.iloc[0]

(4723, 31)


tweet_created                                       2014-12-21 16:11:00
tweet_text            #NewYearsResolution to not put the parking lot...
tweet_category                                                    Humor
tweet_topics          Humor about Personal Growth and Interests Reso...
tweet_location                                       City of Angels, CA
tweet_state                                                          CA
tweet_region                                                       West
user_timezone                                Pacific Time (US & Canada)
user_gender                                                        male
retweet_count                                                       NaN
text                  to not put the parking lot ticket directly in ...
topics_cluster_7                                                      2
_id                                                                   0
pca_x                                                          0

Fields explanations:
* Up to `retweet_count` are the fields of the original base data, which can be [found here](https://github.com/aj-menon/Maven-Analytics/tree/master/2015_new_years_resolution_tweets), along with some further details. 
* The `_id` was added because some visualization tools require an integer unique id.
* The `topics_cluster_7` was added by clustering the semantic embeddings of the `tweet_topics` text into 7 clusters (with the k-means method). This was done because `tweet_topics` alone has 115 unique values, which are a bit too many to use as a categorical. Seven is a more reasonable number for that. More on this in the data preparation section.
* The `text` field contains the `tweet_text` with all hashtag words removed (a `#` followed by alphanumeric characters or underscores).
* All the `*_x` and `*_y` fields are simply planar projections of the `tweet_text` embeddings. All of these, except `tsne` and `umap`, are linear projections. That is, both x and y are obtained by linear transformation. So you can think of all these linear projections as different shadows of the 1500+ dimensional vector representation of the tweets onto different planes in that space. 

## Visualize

Here, we'll look at different scatter plots of the planar projections of the (embeddings of the) tweets, using different fields to control the color of the points. 
When dealing with text data, the usual trick is to transform the text segments into fixed-size numerical "feature" vectors. 
Natural language processing has done this for years, with techniques like bag-of-words, tf-idf, word2vec, etc.
More recently, the field has been revolutionized by transformers: neural networks that can be trained to map text to 
fixed-size numerical vectors. Nowadays, in the new AI age sparked by `ChatGPT`, you can't throw a stone without hitting a transformer-based model that does 
some kind of text processing. 
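
To make "fixed-size numerical vector" concrete, here is a minimal bag-of-words sketch. The toy tweets and the helper name `bag_of_words` are invented for illustration and are not part of the dataset or the original pipeline.

```python
from collections import Counter

# Toy corpus (invented for illustration, not taken from the dataset)
tweets = [
    "eat more vegetables",
    "eat less sugar",
    "read more books",
]

# The vocabulary fixes the vector size: one slot per distinct word
vocabulary = sorted({word for tweet in tweets for word in tweet.split()})


def bag_of_words(text, vocabulary):
    """Count how often each vocabulary word occurs in `text`."""
    counts = Counter(text.split())
    return [counts[word] for word in vocabulary]


vectors = [bag_of_words(tweet, vocabulary) for tweet in tweets]

print(vocabulary)
for tweet, vector in zip(tweets, vectors):
    print(vector, tweet)
```

Every tweet ends up with a vector of the same length, whatever its word count. Learned embeddings follow the same contract, but their dimensions encode meaning rather than raw word counts.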

Here, we used the OpenAI "text-embedding-3-small" embedding model to transform the text of the tweets into fixed size numerical vectors.
This model will output a 1536-dimensional vector for each tweet.
Now, how do we visualize these 1536-dimensional vectors?
Well, we can't. We can't even **actually** visualize 3D vectors, let alone 1536D vectors. 
But we can use the same trick that lets us display 3D objects on the 2D plane of a computer screen: projection. 

That said, there are many ways to project high-dimensional vectors onto a 2D plane.
We'll go through a few of them here, also using different fields to control the color of the points.
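
As a rough illustration of what a linear planar projection is, the sketch below projects random stand-in vectors onto two different randomly chosen planes. The data, the helper `random_plane`, and the choice of random planes are assumptions made for the example. The projections used in this notebook choose their planes far more deliberately.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for the embeddings: 100 random vectors with 1536 dimensions
X = rng.normal(size=(100, 1536))


def random_plane(n_dims, rng):
    """Return an (n_dims, 2) matrix whose columns are two orthonormal directions."""
    A = rng.normal(size=(n_dims, 2))
    Q, _ = np.linalg.qr(A)  # orthonormalize the two columns
    return Q


# Two different planes give two different 2D "shadows" of the same points
plane_a = random_plane(X.shape[1], rng)
plane_b = random_plane(X.shape[1], rng)

xy_a = X @ plane_a  # shape (100, 2)
xy_b = X @ plane_b

print(xy_a.shape, xy_b.shape)
```

Each linear method (PCA, LDA, PLS) amounts to a different rule for picking the `(n_dims, 2)` matrix.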

We'll be using `cosmograph` for our scatter plot needs, since this will give us the ability to interact, share, etc.

In [31]:
from cosmograph import cosmo  # pip install cosmograph

Let's have a look at which fields have fewer than 200 unique values, to get a sense of what we might use to control the characteristics of our scatter plot points. 
(We probably want no more than 10 unique values if we color by the value of that field.)

In [32]:
t = df.nunique().sort_values()
t[t < 200]

user_gender           2
tweet_region          4
topics_cluster_7      7
tweet_category       10
retweet_count        45
user_timezone        50
tweet_state          51
tweet_topics        115
dtype: int64

Let's start with the `PCA` (principal component analysis) projection of our data. 
The two dimensions of the PCA projection are the two directions of the original data that capture the most variance.
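
For readers who want to see the mechanics, here is a small sketch of PCA computed directly with a singular value decomposition. The toy data is made up, and this is not necessarily how the `pca_x` and `pca_y` columns were computed (the appendix uses a helper library for that).

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 200 points in 5 dimensions, stretched differently along each axis
X = rng.normal(size=(200, 5)) * np.array([5.0, 3.0, 1.0, 0.5, 0.1])

# PCA: center the data, then take the top right-singular vectors
X_centered = X - X.mean(axis=0)
_, singular_values, Vt = np.linalg.svd(X_centered, full_matrices=False)

components = Vt[:2]                 # the two directions of largest variance
pca_xy = X_centered @ components.T  # planar projection, shape (200, 2)

# Share of the total variance captured by each of the two directions
explained = singular_values[:2] ** 2 / np.sum(singular_values ** 2)
print(pca_xy.shape, explained.round(2))
```

Note that nothing in this computation looks at the labels, which is why PCA has no particular reason to separate categories.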

In [44]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='pca_x',
    point_y_by='pca_y',
    point_color_by='tweet_category',
    point_size_scale=0.003,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

Here, we can see that this PCA projection manages to distinguish the humor and health categories somewhat,
but the other categories are all over the place. This is not surprising, given that the PCA projection is
a linear transformation maximizing variance. 
If we look at other fields like the region the tweet was sent from, or the declared gender of the tweeter (only 2 -- this is 2014!),
we see that the PCA angle (first two components) doesn't offer a very informative view.

In [272]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='pca_x',
    point_y_by='pca_y',
    point_color_by='tweet_region',
    point_size_scale=0.003,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [45]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='pca_x',
    point_y_by='pca_y',
    point_color_by='user_gender',
    point_size_scale=0.003,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

Now let's have a look at more powerful, non-linear, transformations. 
TSNE and UMAP are two classics in this matter. 
Essentially, the way they work is by trying to preserve the local structure of the data,
which is very useful for visualization purposes, as it allows us to see clusters of data points that are similar to each other.
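
Here is a hedged sketch of that idea on synthetic data, using scikit-learn's `TSNE`. The blobs, the perplexity value, and the nearest-neighbor check are illustrative choices, not the settings used to produce the `tsne_x` and `tsne_y` columns.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(0)

# Toy data: three well-separated blobs of 50 points each in 20 dimensions
centers = rng.normal(scale=10.0, size=(3, 20))
labels = np.repeat(np.arange(3), 50)
X = centers[labels] + rng.normal(size=(150, 20))

# perplexity roughly sets how many neighbors each point "pays attention" to
tsne_xy = TSNE(n_components=2, perplexity=30, random_state=0).fit_transform(X)
print(tsne_xy.shape)

# Check local structure: does each point's nearest 2D neighbor share its blob label?
d2 = ((tsne_xy[:, None, :] - tsne_xy[None, :, :]) ** 2).sum(axis=-1)
np.fill_diagonal(d2, np.inf)
agreement = (labels[d2.argmin(axis=1)] == labels).mean()
print(f"nearest-neighbor label agreement: {agreement:.2f}")
```

Distances between far-apart clusters in such a plot are much less trustworthy than the neighborhoods within them.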

In [271]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='tsne_x',
    point_y_by='tsne_y',
    point_color_by='tweet_category',
    point_size_scale=0.5,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [48]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='umap_x',
    point_y_by='umap_y',
    point_color_by='tweet_category',
    point_size_scale=0.03,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

The separation of categories is definitely better with TSNE and UMAP than with PCA.
One could even convince oneself that the tweets of a given region and gender tend to clump together slightly more than with PCA, though there is still a lot of overlap.

In [52]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='umap_x',
    point_y_by='umap_y',
    point_color_by='tweet_region',
    point_size_scale=0.03,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [49]:
cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by='tsne_x',
    point_y_by='tsne_y',
    point_color_by='user_gender',
    point_size_scale=0.5,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [53]:
len(df)

4723

But where we'll really start to get something exciting to look at is when we start to supervise our projections with the target variable
we're trying to distinguish. 
Hold your horses, oh ye who know enough to critique, but not enough to be beyond blind dogma!
1500 dimensions and only 5000 points? Aren't we overfitting? Aren't we just telling the projection what we want to see?
Yes, we are. But I'm not going to humor that knee-jerk reaction with a response until we make it worse and get excited about the results.

Look at these beautiful tails for health & fitness, south of the figure, and finance stretching west on the figure. 
There's still a big pile of paella, but we've managed to really pull out some undeniable distinctions. 

[Explore this category angle further in the cosmograph app](https://cosmograph.fly.dev/public/408cafaa-59a6-4a6d-888a-0e2005e91e0e)

In [62]:
category_column = 'tweet_category'

cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by=f'{category_column}_x',
    point_y_by=f'{category_column}_y',
    point_color_by=category_column,
    point_size_scale=0.03,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

But let's look at the hard ones now. 

See what `tweet_region` looks like now? Though the midwest is appropriately in the middle of everything, the south, northeast, and west definitely have some tweet angles of their own!

[Explore this region angle further in the cosmograph app](https://cosmograph.fly.dev/public/a9d47b81-31b4-4b3f-a23b-5132bcdc1682)

In [55]:
category_column = 'tweet_region'

cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by=f'{category_column}_x',
    point_y_by=f'{category_column}_y',
    point_color_by=category_column,
    point_size_scale=0.02,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

Even more surprisingly (to me), there's even an angle through which `user_gender` seems to affect the semantics of tweets. 

[Explore this gender angle further in the cosmograph app](https://cosmograph.fly.dev/public/a9af81f4-e5f5-44ff-ad9f-35c3c4113ed1)

In [58]:
category_column = 'user_gender'

cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by=f'{category_column}_x',
    point_y_by=f'{category_column}_y',
    point_color_by=category_column,
    point_size_scale=0.07,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

Should we even dare to have a peep at `tweet_state` (51 unique values) and `tweet_topics` (115 unique values)?

Surely, we won't be able to get anything out of those, right?

In [59]:
category_col = 'tweet_state'

cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by=f'{category_col}_x',
    point_y_by=f'{category_col}_y',
    point_color_by=category_col,
    point_size_scale=0.03,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [61]:
category_col = 'tweet_topics'

cosmo(
    df,
    point_label_by='tweet_text',
    point_x_by=f'{category_col}_x',
    point_y_by=f'{category_col}_y',
    point_color_by=category_col,
    point_size_scale=0.03,
)

Cosmograph(background_color=None, focused_point_ring_color=None, hovered_point_ring_color=None, link_color=Non…

In [183]:
df.tweet_topics.value_counts()

tweet_topics
Other                                                    483
Be more positive                                         463
Improve my attitude                                      175
Humor about Personal Growth and Interests Resolutions    155
Humor about Health and Fitness Resolutions               155
                                                        ... 
Work less                                                  3
Start Master‰Ûªs program                                   3
Start waking up earlier                                    2
Fix up my home office                                      2
Learn to Cook                                              1
Name: count, Length: 115, dtype: int64

## The Curse of Dimensionality: A Fair Question

In high-dimensional spaces, overfitting looms large. When the number of dimensions (1536 here) is of the same order as the number of data points (4723), it's dangerously easy to find projections that "tell a story" simply because they're tuned to the quirks of the data. Are we just finding angles that flatter our hypotheses?
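
The worry is easy to reproduce on synthetic data. The sketch below fits LDA to pure noise with random labels. All sizes are arbitrary choices for illustration and have nothing to do with the tweet data.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(0)

# Pure noise: the features carry no information about the labels
n_train, n_test, n_dims = 150, 150, 120
X_train = rng.normal(size=(n_train, n_dims))
X_test = rng.normal(size=(n_test, n_dims))
y_train = rng.integers(0, 2, size=n_train)
y_test = rng.integers(0, 2, size=n_test)

lda = LinearDiscriminantAnalysis().fit(X_train, y_train)

train_acc = lda.score(X_train, y_train)  # accuracy on the points used for fitting
test_acc = lda.score(X_test, y_test)     # accuracy on points the model never saw
print(f"train accuracy: {train_acc:.2f}, held-out accuracy: {test_acc:.2f}")
```

On the fitting data the classes look separable even though there is nothing to find. The held-out score is what exposes the illusion.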

Surprisingly, the answer often leans toward "no." Semantic embeddings aren't arbitrary - they're structured by design. The OpenAI model, like other transformer-based embeddings, organizes concepts into coherent clusters. Health resolutions, humor, and personal growth are not random blobs in this space - they're distinct regions.

Supervised methods like LDA don't invent these structures; they amplify them. By aligning with categories already latent in the embeddings, they uncover patterns that might otherwise stay hidden. The key is validation - testing these projections on unseen data to ensure the stories we tell aren't mere mirages.
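
One way such a validation could look is sketched below: cross-validated LDA accuracy with the true labels, compared against the same procedure with shuffled labels. The toy "embeddings" are synthetic and assumed to have real class structure. This is an illustration of the check, not a result about the tweet data.

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)

# Toy "embeddings": 3 categories, each shifted toward its own center
n_per_class, n_dims = 100, 50
labels = np.repeat(np.arange(3), n_per_class)
centers = rng.normal(scale=2.0, size=(3, n_dims))
X = centers[labels] + rng.normal(size=(3 * n_per_class, n_dims))

lda = LinearDiscriminantAnalysis()

# Accuracy on held-out folds with the true labels...
real_scores = cross_val_score(lda, X, labels, cv=5)
# ...versus the same procedure after destroying the label/feature link
shuffled_scores = cross_val_score(lda, X, rng.permutation(labels), cv=5)

print(f"true labels: {real_scores.mean():.2f}, shuffled labels: {shuffled_scores.mean():.2f}")
```

If the held-out score with true labels is clearly above the shuffled baseline, the structure the projection shows is more than an artifact of fitting.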

That said, it’s important to remember that we’re not modeling here; we’re visualizing. While modeling and visualization are intricately connected, visualization has a different goal. It’s not primarily about predicting or estimating, so it’s less directly concerned with generalization. Instead, our objective is to offer multiple perspectives—different angles—on the data. By doing so, we empower the analyst to discover and explore diverse narratives, gaining a more complete understanding of the underlying realities the data reflects. Visualization is about inviting curiosity, sparking insights, and uncovering truths hidden in the multidimensional echo of the data.

# Appendix: Data Prep

In [1]:
import os
# USER: Set the save root directory here (this will be used to save data preparation artifacts)

save_rootdir = '~/Dropbox/_odata/figiri/new_years_resolutions/'


# assert that the directory exists
save_rootdir = os.path.abspath(os.path.expanduser(save_rootdir))
assert os.path.isdir(save_rootdir), f"The directory save_rootdir you specified does not exist: {save_rootdir}"

In [2]:
from tabled import DfFiles  # pip install tabled

df_files = DfFiles(save_rootdir)
list(df_files)  # list the files in the directory

['openai_text_embeddings.parquet',
 'with_hashtags/new_years_resolutions_prepped.parquet',
 'with_hashtags/various_scatter_plots.pdf',
 'openai_topics_embeddings.parquet',
 'openai_tweet_text_embeddings.parquet']

## Get source data

In [3]:
import pandas as pd

In [None]:
with_hashtags = True

if with_hashtags:
    prepped_data_file = 'with_hashtags/new_years_resolutions_prepped.parquet'
    embeddings_file = 'openai_tweet_text_embeddings.parquet'  # or 'openai_text_embeddings.parquet'
else:
    prepped_data_file = 'without_hashtags/new_years_resolutions_prepped.parquet'
    embeddings_file = 'openai_text_embeddings.parquet'

if prepped_data_file not in df_files:
    original_src_url = 'https://raw.githubusercontent.com/aj-menon/Maven-Analytics/refs/heads/master/2015_new_years_resolution_tweets/New_years_resolutions.csv'
    df = pd.read_csv(original_src_url)
    df['_id'] = df.index.values  # because some plotting libraries need a unique id as a column
    df_files[prepped_data_file] = df  # save after adding _id so the cached file includes it
else:
    df = df_files[prepped_data_file]

print(f"{df.shape}")
df.iloc[0]

(4723, 11)


tweet_created                                   2014-12-21 16:11:00
tweet_text        #NewYearsResolution to not put the parking lot...
tweet_category                                                Humor
tweet_topics      Humor about Personal Growth and Interests Reso...
tweet_location                                   City of Angels, CA
tweet_state                                                      CA
tweet_region                                                   West
user_timezone                            Pacific Time (US & Canada)
user_gender                                                    male
retweet_count                                                   NaN
_id                                                               0
Name: 0, dtype: object

## Poke around the data

In [5]:
df.tweet_text.nunique()

4723

In [6]:
print(f"The data spans from {df.tweet_created.min()} to {df.tweet_created.max()}")

The data spans from 2014-12-21 16:11:00 to 2015-01-02 09:54:00


What data is missing?

In [7]:
t = df.notna().sum()
t[t!=len(df)]

user_timezone    3496
retweet_count    2932
dtype: int64

In [8]:
print(f"{df.retweet_count.isna().sum()} missing retweet counts")
print(f"{(df.retweet_count == 0).sum()} zero retweet counts")
print(f"{(df.retweet_count > 0).sum()} non-zero retweet counts")

1791 missing retweet counts
2215 zero retweet counts
717 non-zero retweet counts


Which fields are categorical (or can be used as such in visualizations)?

First, let's look at the number of unique values for each column.

In [9]:
df.nunique()

tweet_created     2738
tweet_text        4723
tweet_category      10
tweet_topics       115
tweet_location    2630
tweet_state         51
tweet_region         4
user_timezone       50
user_gender          2
retweet_count       45
_id               4723
dtype: int64

In [10]:
max_uniques = 10
df.nunique()[df.nunique() <= max_uniques]

tweet_category    10
tweet_region       4
user_gender        2
dtype: int64

In [11]:
# How many (lower cased) tweet_text strings contain the substring "#newyearsresolution"?

df.tweet_text.str.lower().str.contains('#newyearsresolution').sum()


3748

## Compute embeddings of the tweet texts

### get clean text (remove hashtag words)

In [39]:
import re 

hashtag_pattern = re.compile(r"#\w+")

def remove_hashtags(text):
    return hashtag_pattern.sub('', text).strip()

df['text'] = df['tweet_text'].apply(remove_hashtags)
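
A quick check of what this pattern does, on invented sample strings (not rows from the dataset). The whitespace-collapsing step at the end is an optional extra, not something the pipeline above applies.

```python
import re

hashtag_pattern = re.compile(r"#\w+")


def remove_hashtags(text):
    return hashtag_pattern.sub('', text).strip()


# A leading hashtag disappears cleanly
print(repr(remove_hashtags("#NewYearsResolution read more books")))

# A hashtag in the middle leaves a double space behind
cleaned = remove_hashtags("read #more books #goals")
print(repr(cleaned))

# Optional extra step: collapse repeated whitespace
print(repr(" ".join(cleaned.split())))
```

Note that a hashtag used as a word inside a sentence is removed too, which can change the meaning of the remaining text.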

### tweet text

In [96]:
embeddings_file = 'openai_tweet_text_embeddings.parquet'

if embeddings_file not in df_files:
    from oa import embeddings

    vectors = embeddings(df.tweet_text)  # ~14s
    vectors_df = pd.DataFrame(data=vectors, index=df.index.values)
    df_files[embeddings_file] = vectors_df
else:
    vectors_df = df_files[embeddings_file]

### text

In [343]:
embeddings_file = 'openai_text_embeddings.parquet'

if embeddings_file not in df_files:
    from oa import embeddings

    vectors = embeddings(df.text)  # ~14s
    vectors_df = pd.DataFrame(data=vectors, index=df.index.values)
    df_files[embeddings_file] = vectors_df
else:
    vectors_df = df_files[embeddings_file]

### tweet topics

In [20]:
topics = df.tweet_topics.unique()
len(topics)

115

In [21]:
topics_save_file = 'openai_topics_embeddings.parquet'

if topics_save_file not in df_files:
    from oa import embeddings

    topics_vectors = embeddings(topics)  # ~14s
    topics_vectors_df = pd.DataFrame(data=topics_vectors, index=topics)
    df_files[topics_save_file] = topics_vectors_df
else:
    topics_vectors_df = df_files[topics_save_file]

## Compute topics (embeddings) clusters

In [22]:
from sklearn.cluster import KMeans

new_col = 'topics_cluster_7'

if new_col not in df.columns:
    kmeans = KMeans(n_clusters=7)
    cluster_indices = kmeans.fit_predict(topics_vectors_df.values)
    clusters_df = pd.DataFrame(
        data=cluster_indices, 
        index=topics_vectors_df.index, 
        columns=[new_col]
    )
    # join df to clusters_df, left on tweet_topics, right on index
    df = df.join(clusters_df, on='tweet_topics')
    df_files[prepped_data_file] = df

df.topics_cluster_7.value_counts()


topics_cluster_7
4    848
2    847
6    773
3    754
0    665
1    516
5    320
Name: count, dtype: int64

## Compute planar projections of embeddings

### Unsupervised projection: pca, tsne and umap

In [37]:
embeddings_file = 'openai_tweet_text_embeddings.parquet'  # or 'openai_text_embeddings.parquet'

In [15]:
# Takes ~50s
from imbed import planar_embeddings, planar_embeddings_dict_to_df

vectors_df = df_files[embeddings_file]

if 'pca_x' not in df.columns:

    pca_xy = planar_embeddings(vectors_df.values, embeddings_func='pca')
    pca_xy = planar_embeddings_dict_to_df(pca_xy, x_col='pca_x', y_col='pca_y')

    df = pd.concat([df, pca_xy], axis=1)
    df_files[prepped_data_file] = df

if 'tsne_x' not in df.columns:

    tsne_xy = planar_embeddings(vectors_df.values, embeddings_func='tsne')
    tsne_xy = planar_embeddings_dict_to_df(tsne_xy, x_col='tsne_x', y_col='tsne_y')

    df = pd.concat([df, tsne_xy], axis=1)
    df_files[prepped_data_file] = df

if 'umap_x' not in df.columns:

    umap_xy = planar_embeddings(vectors_df.values, embeddings_func='umap')
    umap_xy = planar_embeddings_dict_to_df(umap_xy, x_col='umap_x', y_col='umap_y')

    df = pd.concat([df, umap_xy], axis=1)
    df_files[prepped_data_file] = df

df.iloc[0]

tweet_created                                   2014-12-21 16:11:00
tweet_text        #NewYearsResolution to not put the parking lot...
tweet_category                                                Humor
tweet_topics      Humor about Personal Growth and Interests Reso...
tweet_location                                   City of Angels, CA
tweet_state                                                      CA
tweet_region                                                   West
user_timezone                            Pacific Time (US & Canada)
user_gender                                                    male
retweet_count                                                   NaN
_id                                                               0
text              to not put the parking lot ticket directly in ...
pca_x                                                      0.196627
pca_y                                                     -0.210784
tsne_x                                          

### Linear Discriminant Analysis (supervised)

In [16]:
# see what columns have less than 200 unique values (candidates for categorical columns)
t = df.nunique()
t[t < 200]

tweet_category     10
tweet_topics      115
tweet_state        51
tweet_region        4
user_timezone      50
user_gender         2
retweet_count      45
dtype: int64

In [17]:
def add_lda_cols(df, vectors_df, category_col):
    from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA

    lda = LDA(n_components=2)
    lda_xy = lda.fit_transform(vectors_df.values, df[category_col].values)
    lda_xy = pd.DataFrame(data=lda_xy, columns=[f'{category_col}_x', f'{category_col}_y'])
    df[[f'{category_col}_x', f'{category_col}_y']] = lda_xy
    return df


In [23]:
if 'topics_cluster_7_x' not in df.columns:    
    df = add_lda_cols(df, vectors_df, category_col='tweet_topics')
    df = add_lda_cols(df, vectors_df, category_col='tweet_state')
    df = add_lda_cols(df, vectors_df, category_col='tweet_region')
    df = add_lda_cols(df, vectors_df, category_col='tweet_category')
    df = add_lda_cols(df, vectors_df, category_col='topics_cluster_7')

    df_files[prepped_data_file] = df

### Partial least squares

But what about the `user_gender`? Why didn’t we get an `LDA` projection for that? 
Because `user_gender` has only two possible values in this data, and _Linear Discriminant Analysis (LDA)_ can produce at most one fewer dimension than the number of classes. With two classes that leaves a single dimension, so LDA cannot create a 2D projection for binary categories. 
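
The limit is easy to confirm with scikit-learn on random stand-in data (the arrays below are invented for the check):

```python
import numpy as np
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

rng = np.random.default_rng(0)
X = rng.normal(size=(60, 10))
y_binary = np.repeat([0, 1], 30)    # 2 classes -> at most 1 component
y_three = np.repeat([0, 1, 2], 20)  # 3 classes -> at most 2 components

# Two classes: asking for a 2D projection is rejected
try:
    LinearDiscriminantAnalysis(n_components=2).fit(X, y_binary)
    binary_2d_ok = True
except ValueError as e:
    binary_2d_ok = False
    print("2 classes, n_components=2 ->", e)

# Two classes: a 1D projection works
xy_1d = LinearDiscriminantAnalysis(n_components=1).fit_transform(X, y_binary)
# Three classes: a 2D projection works
xy_2d = LinearDiscriminantAnalysis(n_components=2).fit_transform(X, y_three)
print(xy_1d.shape, xy_2d.shape)
```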

_Partial Least Squares (PLS)_, on the other hand, has no such limitation. PLS works by finding linear combinations of features that maximize the covariance between the input data and the target variable, so it can produce a 2D projection even when the target is binary.

In [24]:
from sklearn.cross_decomposition import PLSRegression

category_col = 'user_gender'
X, y = vectors_df.values, df[category_col].values

yy = pd.Categorical(y).codes  # encode the labels as integer codes (PLS needs a numeric target)
yy = yy.reshape(-1, 1)  # Convert y to a 2D array (PLS requires y as a 2D array)
pls_xy = PLSRegression(n_components=2).fit(X, yy).transform(X)
pls_xy = pd.DataFrame(data=pls_xy, columns=[f'{category_col}_x', f'{category_col}_y'])
df[[f'{category_col}_x', f'{category_col}_y']] = pls_xy

df_files[prepped_data_file] = df


## Reorder columns

In [25]:
column_order = [
    'tweet_created',
    'tweet_text',
    'tweet_category',
    'tweet_topics',
    'tweet_location',
    'tweet_state',
    'tweet_region',
    'user_timezone',
    'user_gender',
    'retweet_count',
    'text',
    'topics_cluster_7',
    '_id',
    'pca_x',
    'pca_y',
    'tsne_x',
    'tsne_y',
    'umap_x',
    'umap_y',
    'tweet_topics_x',
    'tweet_topics_y',
    'tweet_state_x',
    'tweet_state_y',
    'topics_cluster_7_x',
    'topics_cluster_7_y',
    'tweet_region_x',
    'tweet_region_y',
    'tweet_category_x',
    'tweet_category_y',
    'user_gender_x',
    'user_gender_y',
]

In [26]:
df = df[column_order]
df_files[prepped_data_file] = df

## Try multiple x/y/color_by combinations

Here we'll use seaborn to make a bunch of scatter plots with different combinations of projections and color field.

In [None]:
xy_fields = [
    ('pca_x', 'pca_y'),
    ('tsne_x', 'tsne_y'),
    ('umap_x', 'umap_y'),
]
category_cols = [
    'tweet_category',
    'tweet_region', 
    'user_gender',
    'topics_cluster_7', 
    'tweet_topics', 
    'tweet_state', 
]
supervised_combinations = [
    ('tweet_category_x', 'tweet_category_y', 'tweet_category'),
    ('tweet_region_x', 'tweet_region_y', 'tweet_region'),
    ('user_gender_x', 'user_gender_y', 'user_gender'),
    ('topics_cluster_7_x', 'topics_cluster_7_y', 'topics_cluster_7'),
    ('tweet_topics_x', 'tweet_topics_y', 'tweet_topics'),
    ('tweet_state_x', 'tweet_state_y', 'tweet_state'),
]

import itertools
from lkj import chunker

xy_category_fields = [tuple([*xy, category]) for xy, category in itertools.product(xy_fields, category_cols)]
xy_category_fields = list(xy_category_fields) + supervised_combinations

batch = next(chunker(xy_category_fields, 6))
batch

(('pca_x', 'pca_y', 'tweet_category'),
 ('pca_x', 'pca_y', 'tweet_state'),
 ('pca_x', 'pca_y', 'tweet_region'),
 ('pca_x', 'pca_y', 'user_gender'),
 ('pca_x', 'pca_y', 'topics_cluster_7'),
 ('pca_x', 'pca_y', 'tweet_topics'))
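
As a sanity check on how many plots this produces, the arithmetic can be reproduced with the standard library alone. The helper `batches_of` is a hypothetical stand-in for the chunker used above, written here only so the sketch is self-contained.

```python
import itertools

xy_fields = [('pca_x', 'pca_y'), ('tsne_x', 'tsne_y'), ('umap_x', 'umap_y')]
category_cols = [
    'tweet_category', 'tweet_region', 'user_gender',
    'topics_cluster_7', 'tweet_topics', 'tweet_state',
]
n_supervised = len(category_cols)  # one supervised projection per category column

unsupervised = [(*xy, cat) for xy, cat in itertools.product(xy_fields, category_cols)]
n_total = len(unsupervised) + n_supervised


def batches_of(items, size):
    """Hypothetical stand-in for a chunker: yield consecutive tuples of `size` items."""
    it = iter(items)
    while batch := tuple(itertools.islice(it, size)):
        yield batch


n_pages = len(list(batches_of(range(n_total), 6)))
print(len(unsupervised), n_total, n_pages)
```

So 3 unsupervised projections times 6 color fields, plus 6 supervised projections, gives 24 scatter plots, or 4 pages of 6.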

In [65]:
# make a 3x2 grid of plots, plotting each of the batch combinations as a scatter plot
# where each of the triples of the batch is (x_col, y_col, color_col)
# legends are removed by default (include_legend=False), since some color fields have many categories

import seaborn as sns
import matplotlib.pyplot as plt

def scatter_32(batch, include_legend=False):
    fig, axs = plt.subplots(3, 2, figsize=(12, 18))
    for i, (x_col, y_col, color_col) in enumerate(batch):
        ax = axs[i // 2, i % 2]
        sns.scatterplot(data=df, x=x_col, y=y_col, hue=color_col, ax=ax, alpha=0.5)
        ax.set_title(f'{x_col[:-2]}\ncolored by {color_col}')
        if not include_legend:
            ax.get_legend().remove()
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_xticks([])
        ax.set_yticks([])
    return fig

def pdf_bytes_of_batches():
    import io
    for batch in chunker(xy_category_fields, 6):
        fig = scatter_32(batch)
        pdf_bytes = io.BytesIO()
        fig.savefig(pdf_bytes, format='pdf')
        plt.close(fig)  # Close the current figure to suppress display
        yield pdf_bytes.getvalue()

from pdfdol import concat_pdfs

combined_pdf_bytes = concat_pdfs(pdf_bytes_of_batches())

# save these pdf bytes to a file in the save_rootdir
from dol import Files 

project_files = Files(save_rootdir)
project_files['various_scatter_plots.pdf'] = combined_pdf_bytes


In [63]:
# make a 2x3 grid of plots, plotting each of the batch combinations as a scatter plot
# where each of the triples of the batch is (x_col, y_col, color_col)
# legends are removed by default (include_legend=False), since some color fields have many categories

import seaborn as sns
import matplotlib.pyplot as plt

def scatter_23(batch, include_legend=False):
    fig, axs = plt.subplots(2, 3, figsize=(18, 12))  # Updated to 2 rows and 3 columns
    for i, (x_col, y_col, color_col) in enumerate(batch):
        ax = axs[i // 3, i % 3]  # Adjust indexing for a 2x3 grid
        sns.scatterplot(data=df, x=x_col, y=y_col, hue=color_col, ax=ax, alpha=0.5)
        ax.set_title(f'{x_col[:-2]}\ncolored by {color_col}')
        if not include_legend:
            ax.get_legend().remove()
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_xticks([])
        ax.set_yticks([])
    return fig

def pdf_bytes_of_batches():
    import io
    for batch in chunker(xy_category_fields, 6):
        fig = scatter_23(batch)
        pdf_bytes = io.BytesIO()
        fig.savefig(pdf_bytes, format='pdf')
        plt.close(fig)  # Close the current figure to suppress display
        yield pdf_bytes.getvalue()

from pdfdol import concat_pdfs

combined_pdf_bytes = concat_pdfs(pdf_bytes_of_batches())

# save these pdf bytes to a file in the save_rootdir
from dol import Files 

project_files = Files(save_rootdir)
project_files['various_scatter_plots_23.pdf'] = combined_pdf_bytes